--- license: mit datasets: - hpprc/emb - hotchpotch/japanese-splade-v1-hard-negatives - hpprc/msmarco-ja language: - ja base_model: - hotchpotch/japanese-splade-base-v1_5 --- 高性能な日本語 [SPLADE](https://github.com/naver/splade) (Sparse Lexical and Expansion Model) モデルです。[テキストからスパースベクトルへの変換デモ](https://huggingface.co/spaces/hotchpotch/japanese-splade-demo-streamlit)で、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。 - ⭐️URLの差し替えを行う また、モデルの学習には[YAST - Yet Another SPLADE or Sparse Trainer](https://github.com/hotchpotch/yast)を使っています。 # 利用方法 ## [YASEM (Yet Another Splade|Sparse Embedder)](https://github.com/hotchpotch/yasem) ```bash pip install yasem ``` ```python from yasem import SpladeEmbedder model_name = "hotchpotch/japanese-splade-base-v2" embedder = SpladeEmbedder(model_name) query = "車の燃費を向上させる方法は?" docs = [ "急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。", "車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。", "車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。", ] print(embedder.rank(query, docs, return_documents=True)) ``` ``` ``` ```json [ { 'corpus_id': 0 , 'score': 4.28 , 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' } , { 'corpus_id': 2 , 'score': 2.47 , 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' } , { 'corpus_id': 1 , 'score': 2.34 , 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' } ] ``` ```python sentences = [query] + docs embeddings = embedder.encode(sentences) similarity = embedder.similarity(embeddings, embeddings) print(similarity) ``` ```json [[5.19151189, 4.28027662, 2.34164901, 2.47221905], [4.28027662, 11.64426784, 5.00328318, 2.15031016], [2.34164901, 5.00328318, 6.05594296, 1.33752085], [2.47221905, 2.15031016, 1.33752085, 9.39414744]] ``` ```python token_values = embedder.get_token_values(embeddings[0]) print(token_values) ``` ```json { '燃費': 1.13, '方法': 1.07, '車': 1.05, '高める': 0.67, '向上': 0.56, '増加': 0.52, '都市': 0.44, 'ガソリン': 0.32, '改善': 0.30, ... ``` ## transformers からの利用 ```python from transformers import AutoModelForMaskedLM, AutoTokenizer import torch model = AutoModelForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) def splade_max_pooling(logits, attention_mask): relu_log = torch.log(1 + torch.relu(logits)) weighted_log = relu_log * attention_mask.unsqueeze(-1) max_val, _ = torch.max(weighted_log, dim=1) return max_val tokens = tokenizer( sentences, return_tensors="pt", padding=True, truncation=True, max_length=512 ) tokens = {k: v.to(model.device) for k, v in tokens.items()} with torch.no_grad(): outputs = model(**tokens) embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"]) similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0) print(similarity) ``` ```python tensor([ [5.1872, 4.2792, 2.3440, 2.4680], [4.2792, 11.6327, 4.9983, 2.1470], [2.3440, 4.9983, 6.0517, 1.3377], [2.4680, 2.1470, 1.3377, 9.3801] ]) ``` # ベンチマークスコア ## retrieval (JMTEB) [JMTEB](https://github.com/sbintuitions/JMTEB) の評価結果です。japanese-splade-base-v1 は [JMTEB をスパースベクトルで評価できるように変更したコード](https://github.com/hotchpotch/JMTEB/tree/add_splade)での評価となっています。 なお、japanese-splade-base-v2 は jaqket, mrtydi, jagovfaqs nlp_jornal のドメインを**学習していません**。 | モデル名 | jagovfaqs | jaqket | mrtydi | nlp_journal
title_abs | nlp_journal
abs_intro | nlp_journal
title_intro | Avg
<512 | Avg
ALL | | ----------------------------- | --------- | ------ | ---------- | ------------------------- | ------------------------- | --------------------------- | ------------- | ------------ | | japanese-splade-base-v2 | 0.7313 | 0.6986 | **0.5106** | **0.9831** | 0.9067 | 0.8026 | **0.7309** | **0.7722** | | GLuCoSE-base-ja-v2 | 0.6979 | 0.6729 | 0.4186 | 0.9511 | 0.9029 | 0.7580 | 0.6851 | 0.7336 | | multilingual-e5-large | 0.7030 | 0.5878 | 0.4363 | 0.9470 | 0.8600 | 0.7248 | 0.6685 | 0.7098 | | ruri-large | 0.7668 | 0.6174 | 0.3803 | 0.9658 | 0.8712 | 0.7797 | 0.6826 | 0.7302 | | jinaai/jina-embeddings-v3 | 0.7150 | 0.4648 | 0.4545 | 0.9562 | 0.9843 | 0.9385 | 0.6476 | 0.7522 | | sarashina-embedding-v1-1b | 0.7168 | 0.7279 | 0.4195 | 0.9696 | 0.9394 | 0.8833 | 0.7085 | 0.7761 | | OpenAI/text-embedding-3-large | 0.7241 | 0.4821 | 0.3488 | 0.9655 | 0.9933 | 0.9547 | 0.6301 | 0.7448 | ## 学習元データセット - [hpprc/emb](https://huggingface.co/datasets/hpprc/emb) - auto-wiki-qa - jsquad - jaquad - auto-wiki-qa-nemotron - quiz-works - quiz-no-mori - baobab-wiki-retrieval - mkqa - [hotchpotch/japanese-splade-v1-hard-negatives](https://huggingface.co/datasets/hotchpotch/japanese-splade-v1-hard-negatives) - mmarco - mqa - msmarco-ja - また英語データセットとして、MS Marcoを利用しています。