Update README.md
Browse files
README.md
CHANGED
@@ -82,9 +82,13 @@ def prepare_inputs(text, has_image=False, device='cuda'):
|
|
82 |
tokenize=False
|
83 |
)
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
88 |
|
89 |
return input_ids, attention_mask
|
90 |
|
|
|
82 |
tokenize=False
|
83 |
)
|
84 |
|
85 |
+
if has_image:
|
86 |
+
text_chunks = [tokenizer(chunk).input_ids for chunk in inputs_formatted.split('<|image|>')]
|
87 |
+
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)
|
88 |
+
attention_mask = torch.ones_like(input_ids).to(device)
|
89 |
+
else:
|
90 |
+
input_ids = torch.tensor(tokenizer(inputs_formatted).input_ids, dtype=torch.long).unsqueeze(0).to(device)
|
91 |
+
attention_mask = torch.ones_like(input_ids).to(device)
|
92 |
|
93 |
return input_ids, attention_mask
|
94 |
|