nielsr HF staff commited on
Commit
7051201
·
1 Parent(s): 45b9dfe

Add code example

Browse files
Files changed (1) hide show
  1. README.md +46 -6
README.md CHANGED
@@ -9,17 +9,57 @@ Vision-and-Language Transformer (ViLT) model pre-trained on GCC+SBU+COCO+VG (200
9
 
10
  Disclaimer: The team releasing ViLT did not write a model card for this model so this model card has been written by the Hugging Face team.
11
 
12
- ## Model description
13
-
14
- (to do)
15
-
16
  ## Intended uses & limitations
17
 
18
- You can use the raw model for visual question answering.
19
 
20
  ### How to use
21
 
22
- (to do)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  ## Training data
25
 
 
9
 
10
  Disclaimer: The team releasing ViLT did not write a model card for this model so this model card has been written by the Hugging Face team.
11
 
 
 
 
 
12
  ## Intended uses & limitations
13
 
14
+ You can use the raw model for masked language modeling given an image and a piece of text with [MASK] tokens.
15
 
16
  ### How to use
17
 
18
+ Here is how to use this model in PyTorch:
19
+
20
+ ```
21
+ from transformers import ViltProcessor, ViltForMaskedLM
22
+ import requests
23
+ from PIL import Image
24
+ import re
25
+
26
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
27
+ image = Image.open(requests.get(url, stream=True).raw)
28
+ text = "a bunch of [MASK] laying on a [MASK]."
29
+
30
+ processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
31
+ model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
32
+
33
+ # prepare inputs
34
+ encoding = processor(image, text, return_tensors="pt")
35
+
36
+ # forward pass
37
+ outputs = model(**encoding)
38
+
39
+ tl = len(re.findall("\[MASK\]", text))
40
+ inferred_token = [text]
41
+
42
+ # gradually fill in the MASK tokens, one by one
43
+ with torch.no_grad():
44
+ for i in range(tl):
45
+ encoded = processor.tokenizer(inferred_token)
46
+ input_ids = torch.tensor(encoded.input_ids).to(device)
47
+ encoded = encoded["input_ids"][0][1:-1]
48
+ outputs = model(input_ids=input_ids, pixel_values=pixel_values)
49
+ mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
50
+ # only take into account text features (minus CLS and SEP token)
51
+ mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
52
+ mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
53
+ # only take into account text
54
+ mlm_values[torch.tensor(encoded) != 103] = 0
55
+ select = mlm_values.argmax().item()
56
+ encoded[select] = mlm_ids[select].item()
57
+ inferred_token = [processor.decode(encoded)]
58
+
59
+ selected_token = ""
60
+ encoded = processor.tokenizer(inferred_token)
61
+ processor.decode(encoded.input_ids[0], skip_special_tokens=True)
62
+ ```
63
 
64
  ## Training data
65