hidehisa-arai
commited on
update
Browse files- README.md +12 -7
- modeling_japanese_clip.py +12 -3
README.md
CHANGED
@@ -7,7 +7,6 @@ tags:
|
|
7 |
- clip
|
8 |
- japanese-clip
|
9 |
---
|
10 |
-
|
11 |
# recruit-jp/japanese-clip-vit-b-32-roberta-base
|
12 |
|
13 |
## Overview
|
@@ -41,17 +40,19 @@ pip install pillow requests transformers torch torchvision sentencepiece
|
|
41 |
```python
|
42 |
import io
|
43 |
import requests
|
44 |
-
from PIL import Image
|
45 |
|
46 |
import torch
|
47 |
import torchvision
|
|
|
48 |
from transformers import AutoTokenizer, AutoModel
|
49 |
|
|
|
50 |
model_name = "recruit-jp/japanese-clip-vit-b-32-roberta-base"
|
51 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
53 |
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
|
54 |
|
|
|
55 |
def _convert_to_rgb(image):
|
56 |
return image.convert('RGB')
|
57 |
|
@@ -68,25 +69,29 @@ preprocess = torchvision.transforms.Compose([
|
|
68 |
def tokenize(tokenizer, texts):
|
69 |
texts = ["[CLS]" + text for text in texts]
|
70 |
encodings = [
|
|
|
71 |
tokenizer(text, max_length=77, padding="max_length", truncation=True, add_special_tokens=False)["input_ids"]
|
72 |
for text in texts
|
73 |
]
|
74 |
return torch.LongTensor(encodings)
|
75 |
|
|
|
76 |
# Run!
|
77 |
-
image = Image.open(
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
image = preprocess(image).unsqueeze(0).to(device)
|
79 |
text = tokenize(tokenizer, texts=["犬", "猫", "象"]).to(device)
|
80 |
-
|
81 |
with torch.inference_mode():
|
82 |
image_features = model.get_image_features(image)
|
83 |
image_features /= image_features.norm(dim=-1, keepdim=True)
|
84 |
-
|
85 |
text_features = model.get_text_features(input_ids=text)
|
86 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
87 |
-
|
88 |
probs = image_features @ text_features.T
|
89 |
-
|
90 |
print("Label probs:", probs.cpu().numpy()[0])
|
91 |
```
|
92 |
|
|
|
7 |
- clip
|
8 |
- japanese-clip
|
9 |
---
|
|
|
10 |
# recruit-jp/japanese-clip-vit-b-32-roberta-base
|
11 |
|
12 |
## Overview
|
|
|
40 |
```python
|
41 |
import io
|
42 |
import requests
|
|
|
43 |
|
44 |
import torch
|
45 |
import torchvision
|
46 |
+
from PIL import Image
|
47 |
from transformers import AutoTokenizer, AutoModel
|
48 |
|
49 |
+
|
50 |
model_name = "recruit-jp/japanese-clip-vit-b-32-roberta-base"
|
51 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
53 |
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
|
54 |
|
55 |
+
|
56 |
def _convert_to_rgb(image):
|
57 |
return image.convert('RGB')
|
58 |
|
|
|
69 |
def tokenize(tokenizer, texts):
|
70 |
texts = ["[CLS]" + text for text in texts]
|
71 |
encodings = [
|
72 |
+
# NOTE: the maximum token length that can be fed into this model is 77
|
73 |
tokenizer(text, max_length=77, padding="max_length", truncation=True, add_special_tokens=False)["input_ids"]
|
74 |
for text in texts
|
75 |
]
|
76 |
return torch.LongTensor(encodings)
|
77 |
|
78 |
+
|
79 |
# Run!
|
80 |
+
image = Image.open(
|
81 |
+
io.BytesIO(
|
82 |
+
requests.get(
|
83 |
+
'https://images.pexels.com/photos/2253275/pexels-photo-2253275.jpeg?auto=compress&cs=tinysrgb&dpr=3&h=750&w=1260'
|
84 |
+
).content
|
85 |
+
)
|
86 |
+
)
|
87 |
image = preprocess(image).unsqueeze(0).to(device)
|
88 |
text = tokenize(tokenizer, texts=["犬", "猫", "象"]).to(device)
|
|
|
89 |
with torch.inference_mode():
|
90 |
image_features = model.get_image_features(image)
|
91 |
image_features /= image_features.norm(dim=-1, keepdim=True)
|
|
|
92 |
text_features = model.get_text_features(input_ids=text)
|
93 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
|
|
94 |
probs = image_features @ text_features.T
|
|
|
95 |
print("Label probs:", probs.cpu().numpy()[0])
|
96 |
```
|
97 |
|
modeling_japanese_clip.py
CHANGED
@@ -84,7 +84,9 @@ class AttentionalPooler(nn.Module):
|
|
84 |
):
|
85 |
super().__init__()
|
86 |
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
87 |
-
self.attn = nn.MultiheadAttention(
|
|
|
|
|
88 |
self.ln_q = norm_layer(d_model)
|
89 |
self.ln_k = norm_layer(context_dim)
|
90 |
|
@@ -92,7 +94,9 @@ class AttentionalPooler(nn.Module):
|
|
92 |
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
93 |
N = x.shape[1]
|
94 |
q = self.ln_q(self.query)
|
95 |
-
out = self.attn(
|
|
|
|
|
96 |
return out.permute(1, 0, 2) # LND -> NLD
|
97 |
|
98 |
|
@@ -187,7 +191,12 @@ class Transformer(nn.Module):
|
|
187 |
|
188 |
self.resblocks = nn.ModuleList([
|
189 |
ResidualAttentionBlock(
|
190 |
-
width,
|
|
|
|
|
|
|
|
|
|
|
191 |
for _ in range(layers)
|
192 |
])
|
193 |
|
|
|
84 |
):
|
85 |
super().__init__()
|
86 |
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
87 |
+
self.attn = nn.MultiheadAttention(
|
88 |
+
d_model, n_head, kdim=context_dim, vdim=context_dim
|
89 |
+
)
|
90 |
self.ln_q = norm_layer(d_model)
|
91 |
self.ln_k = norm_layer(context_dim)
|
92 |
|
|
|
94 |
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
95 |
N = x.shape[1]
|
96 |
q = self.ln_q(self.query)
|
97 |
+
out = self.attn(
|
98 |
+
q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False
|
99 |
+
)[0]
|
100 |
return out.permute(1, 0, 2) # LND -> NLD
|
101 |
|
102 |
|
|
|
191 |
|
192 |
self.resblocks = nn.ModuleList([
|
193 |
ResidualAttentionBlock(
|
194 |
+
width,
|
195 |
+
heads,
|
196 |
+
mlp_ratio,
|
197 |
+
ls_init_value=ls_init_value,
|
198 |
+
act_layer=act_layer,
|
199 |
+
norm_layer=norm_layer)
|
200 |
for _ in range(layers)
|
201 |
])
|
202 |
|