File size: 6,272 Bytes
fa0ad02 5f6b23e fa0ad02 07a62c4 fa0ad02 07a62c4 fa0ad02 07a62c4 fa0ad02 07a62c4 fa0ad02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
---
license: mit
datasets:
- hpprc/emb
- hotchpotch/japanese-splade-v1-hard-negatives
- hpprc/msmarco-ja
language:
- ja
base_model:
- hotchpotch/japanese-splade-base-v1_5
---
高性能な日本語 [SPLADE](https://github.com/naver/splade) (Sparse Lexical and Expansion Model) モデルです。[テキストからスパースベクトルへの変換デモ](https://huggingface.co/spaces/hotchpotch/japanese-splade-demo-streamlit)で、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。
- ⭐️URLの差し替えを行う
また、モデルの学習には[YAST - Yet Another SPLADE or Sparse Trainer](https://github.com/hotchpotch/yast)を使っています。
# 利用方法
## [YASEM (Yet Another Splade|Sparse Embedder)](https://github.com/hotchpotch/yasem)
yasem を使うと、SPLADEの推論を簡単に利用できます。
```bash
pip install yasem
```
```python
from yasem import SpladeEmbedder
model_name = "hotchpotch/japanese-splade-base-v2"
embedder = SpladeEmbedder(model_name)
query = "車の燃費を向上させる方法は?"
docs = [
"急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。",
"車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。",
"車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
]
print(embedder.rank(query, docs, return_documents=True))
```
```
```
```json
[
{ 'corpus_id': 0
, 'score': 4.28
, 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' }
,
{ 'corpus_id': 2
, 'score': 2.47
, 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' }
,
{ 'corpus_id': 1
, 'score': 2.34
, 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' }
]
```
```python
sentences = [query] + docs
embeddings = embedder.encode(sentences)
similarity = embedder.similarity(embeddings, embeddings)
print(similarity)
```
```json
[[5.19151189, 4.28027662, 2.34164901, 2.47221905],
[4.28027662, 11.64426784, 5.00328318, 2.15031016],
[2.34164901, 5.00328318, 6.05594296, 1.33752085],
[2.47221905, 2.15031016, 1.33752085, 9.39414744]]
```
```python
token_values = embedder.get_token_values(embeddings[0])
print(token_values)
```
```json
{
'燃費': 1.13,
'方法': 1.07,
'車': 1.05,
'高める': 0.67,
'向上': 0.56,
'増加': 0.52,
'都市': 0.44,
'ガソリン': 0.32,
'改善': 0.30,
...
```
## transformers からの利用
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def splade_max_pooling(logits, attention_mask):
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
return max_val
tokens = tokenizer(
sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}
with torch.no_grad():
outputs = model(**tokens)
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])
similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
print(similarity)
```
```python
tensor([
[5.1872, 4.2792, 2.3440, 2.4680],
[4.2792, 11.6327, 4.9983, 2.1470],
[2.3440, 4.9983, 6.0517, 1.3377],
[2.4680, 2.1470, 1.3377, 9.3801]
])
```
# ベンチマークスコア
## retrieval (JMTEB)
[JMTEB](https://github.com/sbintuitions/JMTEB) の評価結果です。
japanese-splade-base-v2 は [JMTEB をスパースベクトルで評価できるように変更したコード](https://github.com/hotchpotch/JMTEB/tree/add_splade)での評価となっています。
なお、japanese-splade-base-v2 は jaqket, mrtydi, jagovfaqs nlp_jornal のドメインを**学習していません**。
| モデル名 | jagovfaqs | jaqket | mrtydi | nlp_journal <br>title_abs | nlp_journal <br>abs_intro | nlp_journal <br>title_intro | Avg <br><512 | Avg <br>ALL |
| ----------------------------- | --------- | ------ | ---------- | ------------------------- | ------------------------- | --------------------------- | ------------- | ------------ |
| japanese-splade-base-v2 | 0.7313 | 0.6986 | **0.5106** | **0.9831** | 0.9067 | 0.8026 | **0.7309** | 0.7722 |
| GLuCoSE-base-ja-v2 | 0.6979 | 0.6729 | 0.4186 | 0.9511 | 0.9029 | 0.7580 | 0.6851 | 0.7336 |
| multilingual-e5-large | 0.7030 | 0.5878 | 0.4363 | 0.9470 | 0.8600 | 0.7248 | 0.6685 | 0.7098 |
| ruri-large | **0.7668** | 0.6174 | 0.3803 | 0.9658 | 0.8712 | 0.7797 | 0.6826 | 0.7302 |
| jinaai/jina-embeddings-v3 | 0.7150 | 0.4648 | 0.4545 | 0.9562 | 0.9843 | 0.9385 | 0.6476 | 0.7522 |
| sarashina-embedding-v1-1b | 0.7168 | **0.7279** | 0.4195 | 0.9696 | 0.9394 | 0.8833 | 0.7085 | **0.7761** |
| OpenAI/text-embedding-3-large | 0.7241 | 0.4821 | 0.3488 | 0.9655 | **0.9933** | **0.9547** | 0.6301 | 0.7448 |
## 学習元データセット
- [hpprc/emb](https://huggingface.co/datasets/hpprc/emb)
- auto-wiki-qa
- jsquad
- jaquad
- auto-wiki-qa-nemotron
- quiz-works
- quiz-no-mori
- baobab-wiki-retrieval
- mkqa
- [hotchpotch/japanese-splade-v1-hard-negatives](https://huggingface.co/datasets/hotchpotch/japanese-splade-v1-hard-negatives)
- mmarco
- mqa
- msmarco-ja
-
また英語データセットとして、MS Marcoを利用しています。 |