Safetensors
Japanese
bert
japanese-splade-v2 / README.md
hotchpotch's picture
Update README.md
07a62c4 verified
|
raw
history blame
6.27 kB
metadata
license: mit
datasets:
  - hpprc/emb
  - hotchpotch/japanese-splade-v1-hard-negatives
  - hpprc/msmarco-ja
language:
  - ja
base_model:
  - hotchpotch/japanese-splade-base-v1_5

高性能な日本語 SPLADE (Sparse Lexical and Expansion Model) モデルです。テキストからスパースベクトルへの変換デモで、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。

  • ⭐️URLの差し替えを行う

また、モデルの学習にはYAST - Yet Another SPLADE or Sparse Trainerを使っています。

利用方法

YASEM (Yet Another Splade|Sparse Embedder)

yasem を使うと、SPLADEの推論を簡単に利用できます。

pip install yasem
from yasem import SpladeEmbedder

model_name = "hotchpotch/japanese-splade-base-v2"
embedder = SpladeEmbedder(model_name)

query = "車の燃費を向上させる方法は?"
docs = [
    "急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。",
    "車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。",
    "車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
]

print(embedder.rank(query, docs, return_documents=True))

[
 { 'corpus_id': 0
 , 'score': 4.28
 , 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' }
 ,
 { 'corpus_id': 2
 , 'score': 2.47
 , 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' }
 ,
 { 'corpus_id': 1
 , 'score': 2.34
 , 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' }
]
sentences = [query] + docs

embeddings = embedder.encode(sentences)
similarity = embedder.similarity(embeddings, embeddings)

print(similarity)
[[5.19151189, 4.28027662, 2.34164901, 2.47221905],
[4.28027662, 11.64426784, 5.00328318, 2.15031016],
[2.34164901, 5.00328318, 6.05594296, 1.33752085],
[2.47221905, 2.15031016, 1.33752085, 9.39414744]]
token_values = embedder.get_token_values(embeddings[0])
print(token_values)
{
 '燃費': 1.13,
 '方法': 1.07,
 '車': 1.05,
 '高める': 0.67,
 '向上': 0.56,
 '増加': 0.52,
 '都市': 0.44,
 'ガソリン': 0.32,
 '改善': 0.30,
 ...

transformers からの利用


from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def splade_max_pooling(logits, attention_mask):
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    return max_val

tokens = tokenizer(
    sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}

with torch.no_grad():
    outputs = model(**tokens)
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])

similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
print(similarity)
tensor([
   [5.1872, 4.2792, 2.3440, 2.4680],
   [4.2792, 11.6327, 4.9983, 2.1470],
   [2.3440, 4.9983, 6.0517, 1.3377],
   [2.4680, 2.1470, 1.3377, 9.3801]
])

ベンチマークスコア

retrieval (JMTEB)

JMTEB の評価結果です。

japanese-splade-base-v2 は JMTEB をスパースベクトルで評価できるように変更したコードでの評価となっています。 なお、japanese-splade-base-v2 は jaqket, mrtydi, jagovfaqs nlp_jornal のドメインを学習していません

モデル名 jagovfaqs jaqket mrtydi nlp_journal
title_abs
nlp_journal
abs_intro
nlp_journal
title_intro
Avg
<512
Avg
ALL
japanese-splade-base-v2 0.7313 0.6986 0.5106 0.9831 0.9067 0.8026 0.7309 0.7722
GLuCoSE-base-ja-v2 0.6979 0.6729 0.4186 0.9511 0.9029 0.7580 0.6851 0.7336
multilingual-e5-large 0.7030 0.5878 0.4363 0.9470 0.8600 0.7248 0.6685 0.7098
ruri-large 0.7668 0.6174 0.3803 0.9658 0.8712 0.7797 0.6826 0.7302
jinaai/jina-embeddings-v3 0.7150 0.4648 0.4545 0.9562 0.9843 0.9385 0.6476 0.7522
sarashina-embedding-v1-1b 0.7168 0.7279 0.4195 0.9696 0.9394 0.8833 0.7085 0.7761
OpenAI/text-embedding-3-large 0.7241 0.4821 0.3488 0.9655 0.9933 0.9547 0.6301 0.7448

学習元データセット