Commit 24682d3 · 1 parent: 0a3c175
pain committed

Delete model_loading.py

Files changed (1):
  1. model_loading.py (+0, -59)
model_loading.py DELETED
@@ -1,59 +0,0 @@
-
- import os
- import numpy as np
- import pickle
- import torch
- import transformers
- import torch.nn.functional as F
- from PIL import Image
- from open_clip import create_model_from_pretrained, get_tokenizer  # works on open-clip-torch>=2.23.0, timm>=0.9.8
- import json
- import pickle
- import gradio as gr
-
-
- # XLM model functions
- from multilingual_clip import pt_multilingual_clip
- import transformers
-
-
- # Our model definition
-
- class MultilingualClipEdited(torch.nn.Module):
-     def __init__(self, model_name, tokenizer_name, head_name, weights_dir='data/weights/', cache_dir=None, in_features=None, out_features=None):
-         super().__init__()
-         self.model_name = model_name
-         self.tokenizer_name = tokenizer_name
-         self.head_path = weights_dir + head_name
-
-         self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
-         self.transformer = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
-         self.clip_head = torch.nn.Linear(in_features=in_features, out_features=out_features)
-         self._load_head()
-
-     def forward(self, txt):
-         txt_tok = self.tokenizer(txt, padding=True, return_tensors='pt')
-         embs = self.transformer(**txt_tok)[0]
-         att = txt_tok['attention_mask']
-         embs = (embs * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]
-         return self.clip_head(embs)
-
-     def _load_head(self):
-         with open(self.head_path, 'rb') as f:
-             lin_weights = pickle.loads(f.read())
-         self.clip_head.weight = torch.nn.Parameter(torch.tensor(lin_weights[0]).float().t())
-         self.clip_head.bias = torch.nn.Parameter(torch.tensor(lin_weights[1]).float())
-
- AVAILABLE_MODELS = {
-     'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M': {
-         'model_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
-         'tokenizer_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
-         'head_name': 'arabertv2-vit-B-16-siglibheads_of_the_model_arabertv2-ViT-B-16-SigLIP-512-155_.pickle'
-     },
-
- }
-
-
- def load_model(name, cache_dir=None, in_features=None, out_features=None):
-     config = AVAILABLE_MODELS[name]
-     return MultilingualClipEdited(**config, cache_dir=cache_dir, in_features=in_features, out_features=out_features)
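
For reference, a minimal sketch of how the deleted load_model helper would have been used. The in_features/out_features values are assumptions (the deleted code leaves them to the caller; 768 is only a plausible choice for the bert-base-arabertv2 text encoder and the ViT-B-16-SigLIP projection), and it assumes the head pickle listed in AVAILABLE_MODELS is present under data/weights/.

import torch
from model_loading import load_model  # the module removed by this commit

model = load_model(
    'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
    cache_dir=None,
    in_features=768,   # assumed hidden size of the AraBERTv2 text encoder
    out_features=768,  # assumed SigLIP text-embedding dimension
)

with torch.no_grad():
    # forward() tokenizes, mean-pools token embeddings over the attention
    # mask, and projects through the pickled linear head.
    text_embs = model(['قطة تجلس على الأريكة'])  # "a cat sitting on the couch"

print(text_embs.shape)  # torch.Size([1, 768]) under the assumptions above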