File size: 8,731 Bytes
fa0ad02 5423499 fa0ad02 c62178f 5f6b23e fa0ad02 cb14435 fa0ad02 0b209a5 fa0ad02 0b209a5 fa0ad02 07a62c4 cb14435 c62178f fa0ad02 cb14435 c62178f fa0ad02 07a62c4 fa0ad02 07a62c4 fa0ad02 fa1683d fa0ad02 f37f157 fa1683d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
---
license: mit
datasets:
- hpprc/emb
- hotchpotch/japanese-splade-v1-hard-negatives
- hpprc/msmarco-ja
language:
- ja
base_model:
- hotchpotch/japanese-splade-base-v1_5
---
高性能な日本語 [SPLADE](https://github.com/naver/splade) (Sparse Lexical and Expansion Model) モデルです。[テキストからスパースベクトルへの変換デモ](https://huggingface.co/spaces/hotchpotch/japanese-splade-demo-streamlit)で、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。
- ⭐️記事へのリンク
また、モデルの学習には[YAST - Yet Another SPLADE or Sparse Trainer](https://github.com/hotchpotch/yast)を使っています。
# 利用方法
## [YASEM (Yet Another Splade|Sparse Embedder)](https://github.com/hotchpotch/yasem)
[YASEM](https://github.com/hotchpotch/yasem) を利用することで、SPLADEの推論・単語トークンの確認を簡単に行えます。
```bash
pip install yasem
```
```python
from yasem import SpladeEmbedder
model_name = "hotchpotch/japanese-splade-v2"
embedder = SpladeEmbedder(model_name)
query = "車の燃費を向上させる方法は?"
docs = [
"急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。",
"車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。",
"車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
]
print(embedder.rank(query, docs, return_documents=True))
```
```python
[
{ 'corpus_id': 0
, 'score': 4.28
, 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' }
,
{ 'corpus_id': 2
, 'score': 2.47
, 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' }
,
{ 'corpus_id': 1
, 'score': 2.34
, 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' }
]
```
```python
sentences = [query] + docs
embeddings = embedder.encode(sentences)
similarity = embedder.similarity(embeddings, embeddings)
print(similarity)
```
```json
[[5.19151189, 4.28027662, 2.34164901, 2.47221905],
[4.28027662, 11.64426784, 5.00328318, 2.15031016],
[2.34164901, 5.00328318, 6.05594296, 1.33752085],
[2.47221905, 2.15031016, 1.33752085, 9.39414744]]
```
```python
token_values = embedder.get_token_values(embeddings[0])
print(token_values)
```
```python
{
'燃費': 1.13,
'方法': 1.07,
'車': 1.05,
'高める': 0.67,
'向上': 0.56,
'増加': 0.52,
'都市': 0.44,
'ガソリン': 0.32,
'改善': 0.30,
...
```
## transformers からの利用
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def splade_max_pooling(logits, attention_mask):
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
return max_val
tokens = tokenizer(
sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}
with torch.no_grad():
outputs = model(**tokens)
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])
similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
print(similarity)
```
```python
tensor([
[5.1872, 4.2792, 2.3440, 2.4680],
[4.2792, 11.6327, 4.9983, 2.1470],
[2.3440, 4.9983, 6.0517, 1.3377],
[2.4680, 2.1470, 1.3377, 9.3801]
])
```
# ベンチマークスコア
## retrieval (JMTEB)
[JMTEB](https://github.com/sbintuitions/JMTEB) の評価結果です。
japanese-splade-v2 は [JMTEB をスパースベクトルで評価できるように変更したコード](https://github.com/hotchpotch/JMTEB/tree/add_splade)での評価となっています。
なお、japanese-splade-v2 は JMTEB タスクである jaqket(や派生のjaqra), mrtydi(と派生のmiracl), jagovfaqs, nlp_jornal のデータセットのtrain,dev, testなどのデータは **学習に利用していません**。
| モデル名 | jagovfaqs | jaqket | mrtydi | nlp_journal <br>title_abs | nlp_journal <br>abs_intro | nlp_journal <br>title_intro | Avg <br><512 | Avg <br>ALL |
| ----------------------------- | --------- | ------ | ---------- | ------------------------- | ------------------------- | --------------------------- | ------------- | ------------ |
| japanese-splade-v2 | 0.7313 | 0.6986 | **0.5106** | **0.9831** | 0.9067 | 0.8026 | **0.7309** | 0.7722 |
| japanese-splade-base-v1 | 0.6499 | 0.6992 | 0.4365 | 0.8967 | 0.9766 | 0.8203 | 0.6906 | 0.7465 |
| GLuCoSE-base-ja-v2 | 0.6979 | 0.6729 | 0.4186 | 0.9511 | 0.9029 | 0.7580 | 0.6851 | 0.7336 |
| multilingual-e5-large | 0.7030 | 0.5878 | 0.4363 | 0.9470 | 0.8600 | 0.7248 | 0.6685 | 0.7098 |
| ruri-large | **0.7668** | 0.6174 | 0.3803 | 0.9658 | 0.8712 | 0.7797 | 0.6826 | 0.7302 |
| jinaai/jina-embeddings-v3 | 0.7150 | 0.4648 | 0.4545 | 0.9562 | 0.9843 | 0.9385 | 0.6476 | 0.7522 |
| sarashina-embedding-v1-1b | 0.7168 | **0.7279** | 0.4195 | 0.9696 | 0.9394 | 0.8833 | 0.7085 | **0.7761** |
| OpenAI/text-embedding-3-large | 0.7241 | 0.4821 | 0.3488 | 0.9655 | **0.9933** | **0.9547** | 0.6301 | 0.7448 |
## スパース性
v1 ではスパース性が強すぎたので、v2 ではバランスをとったスパース性を持たせています。
- https://github.com/hotchpotch/yast/blob/main/utils/JMTEB_L0.py
で計測しています。
| Target | jaqket-query | jaqket-docs | mrtydi-query | mrtydi-docs | jagovfaqs_22k-query | jagovfaqs_22k-docs | nlp_journal_title_abs-query | nlp_journal_title_abs-docs | nlp_journal_title_intro-query | nlp_journal_title_intro-docs | nlp_journal_abs_intro-query | nlp_journal_abs_intro-docs |
|-----------------------------------------|--------------|-------------|--------------|-------------|---------------------|--------------------|-----------------------------|----------------------------|------------------------------|-----------------------------|-----------------------------|----------------------------|
| v1 | 23.3 | 146.2 | 13.8 | 89.3 | 27.9 | 73.2 | 19 | 75.2 | 19 | 95.7 | 75.3 | 95.7 |
| v1-mmarco-only | 38.9 | 231.8 | 20.5 | 100.4 | 43.4 | 97.9 | 26.4 | 126.9 | 26.4 | 182 | 127.2 | 182 |
| v1_5 | 36.7 | 268.7 | 22.8 | 237.6 | 47.9 | 237.3 | 34.9 | 225.6 | 34.9 | 235.2 | 224.5 | 235.2 |
| v2 | 29.8 | 379.6 | 19.4 | 176.4 | 42 | 189.8 | 29 | 235.8 | 29 | 304.9 | 233.8 | 304.9 |
# 学習元データセット
- [hpprc/emb](https://huggingface.co/datasets/hpprc/emb)
- auto-wiki-qa
- jsquad
- jaquad
- auto-wiki-qa-nemotron
- quiz-works
- quiz-no-mori
- baobab-wiki-retrieval
- mkqa
- [hotchpotch/japanese-splade-v1-hard-negatives](https://huggingface.co/datasets/hotchpotch/japanese-splade-v1-hard-negatives)
- mmarco
- mqa
- msmarco-ja
- [hotchpotch/mmarco-hard-negatives-reranker-score](https://huggingface.co/datasets/hotchpotch/mmarco-hard-negatives-reranker-score)
- english
# ライセンス
MIT |