Delete model_loading.py
model_loading.py +0 -59
model_loading.py
DELETED
@@ -1,59 +0,0 @@
-
-import os
-import numpy as np
-import pickle
-import torch
-import transformers
-import torch.nn.functional as F
-from PIL import Image
-from open_clip import create_model_from_pretrained, get_tokenizer  # works on open-clip-torch>=2.23.0, timm>=0.9.8
-import json
-import pickle
-import gradio as gr
-
-
-# XLM model functions
-from multilingual_clip import pt_multilingual_clip
-import transformers
-
-
-# Our model definition
-
-class MultilingualClipEdited(torch.nn.Module):
-    def __init__(self, model_name, tokenizer_name, head_name, weights_dir='data/weights/', cache_dir=None, in_features=None, out_features=None):
-        super().__init__()
-        self.model_name = model_name
-        self.tokenizer_name = tokenizer_name
-        self.head_path = weights_dir + head_name
-
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
-        self.transformer = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
-        self.clip_head = torch.nn.Linear(in_features=in_features, out_features=out_features)
-        self._load_head()
-
-    def forward(self, txt):
-        txt_tok = self.tokenizer(txt, padding=True, return_tensors='pt')
-        embs = self.transformer(**txt_tok)[0]
-        att = txt_tok['attention_mask']
-        embs = (embs * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]
-        return self.clip_head(embs)
-
-    def _load_head(self):
-        with open(self.head_path, 'rb') as f:
-            lin_weights = pickle.loads(f.read())
-        self.clip_head.weight = torch.nn.Parameter(torch.tensor(lin_weights[0]).float().t())
-        self.clip_head.bias = torch.nn.Parameter(torch.tensor(lin_weights[1]).float())
-
-AVAILABLE_MODELS = {
-    'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M': {
-        'model_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
-        'tokenizer_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
-        'head_name': 'arabertv2-vit-B-16-siglibheads_of_the_model_arabertv2-ViT-B-16-SigLIP-512-155_.pickle'
-    },
-
-}
-
-
-def load_model(name, cache_dir=None, in_features=None, out_features=None):
-    config = AVAILABLE_MODELS[name]
-    return MultilingualClipEdited(**config, cache_dir=cache_dir, in_features=in_features, out_features=out_features)
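
Note on the deleted _load_head method: it expects the pickle at data/weights/<head_name> to hold a two-element sequence, with the linear weight stored as an (in_features, out_features) array (hence the .t() transpose when assigning to torch.nn.Linear.weight, which PyTorch keeps as (out_features, in_features)) and the bias as an (out_features,) array. A minimal sketch of writing a head file in that layout, assuming 768-dimensional embeddings on both sides (the trained head's actual dimensions are not recorded in this commit):

import pickle
import numpy as np

in_features, out_features = 768, 768  # assumed dimensions, not taken from this commit
# Weight is stored transposed relative to nn.Linear.weight, matching the .t() in _load_head
weight = np.random.randn(in_features, out_features).astype(np.float32)
bias = np.zeros(out_features, dtype=np.float32)

with open('data/weights/example_head.pickle', 'wb') as f:  # hypothetical file name
    pickle.dump((weight, bias), f)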
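For reference, a minimal sketch of how the deleted module would have been used, assuming the pickled head file is present under data/weights/ and that in_features and out_features are both 768 (hypothetical values; they are not confirmed by this commit):

import torch
from model_loading import load_model  # the module removed in this commit

name = 'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M'
model = load_model(name, in_features=768, out_features=768)  # dimensions are assumptions

texts = ['قطة تجلس على الأريكة']  # "a cat sitting on the couch"
with torch.no_grad():
    embeddings = model(texts)  # forward() tokenizes, mean-pools over the attention mask, then applies the head
print(embeddings.shape)  # torch.Size([1, 768]) under the assumed out_features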