We initialize SPLADE-japanese from tohoku-nlp/bert-base-japanese-v2. This model is trained on the mMARCO Japanese dataset.

from transformers import AutoModelForMaskedLM,AutoTokenizer
import torch
import numpy as np

# Load the SPLADE-japanese tokenizer and masked-LM model by their Hugging Face repo id.
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanese")
model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanese")

# Example query: "What kind of research is conducted at the University of Tsukuba?"
query = "筑波大学では何の研究が行われているか?"

def encode_query(query, tokenizer, model):
    """Encode *query* into a SPLADE-style vocabulary weight tensor.

    The masked-LM logits are passed through log(1 + ReLU(x)). Positions whose
    attention mask is 0 are zeroed. The result is max-pooled over the sequence
    dimension (dim=1), giving one weight per vocabulary entry per input row.

    Args:
        query: Text (or batch of texts) accepted by ``tokenizer``.
        tokenizer: Callable returning a mapping with at least ``attention_mask``.
        model: Callable whose output exposes ``.logits``.

    Returns:
        The pooled tensor, with the sequence dimension removed.
    """
    inputs = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs, return_dict=True).logits
    # Broadcast the (batch, seq) mask over the vocabulary axis.
    mask = inputs['attention_mask'][..., None]
    saturated = torch.log(1 + torch.relu(logits)) * mask
    return saturated.max(dim=1).values

def get_topk_tokens(reps, vocab_dict, topk):
    """Return the ``topk`` highest-weighted tokens of the first row of *reps*.

    Args:
        reps: 2-D tensor of per-vocabulary weights. Only row 0 is read.
        vocab_dict: Mapping from token id (int) to token string.
        topk: Number of top entries to consider. Values larger than the
            vocabulary dimension are clamped instead of raising.

    Returns:
        Dict of token string -> weight scaled by 100 and rounded to int.
        Entries whose rounded weight is not positive are dropped.
    """
    # torch.topk raises when k exceeds the size of the selected dimension.
    k = min(topk, reps.size(1))
    topk_values, topk_indices = torch.topk(reps, k, dim=1)
    # .detach().cpu() lets .numpy() work for CUDA tensors and tensors that require grad.
    values = np.rint(topk_values.detach().cpu().numpy() * 100).astype(int)
    return {
        vocab_dict[token_id]: int(value)
        for token_id, value in zip(topk_indices[0].tolist(), values[0])
        if value > 0
    }


# Invert the tokenizer vocabulary so token ids can be turned back into strings.
vocab_dict = {index: token for token, index in tokenizer.get_vocab().items()}
# Show roughly 0.1% of the vocabulary (vocabulary size // 1000 terms).
topk = len(vocab_dict) // 1000

model_output = encode_query(query, tokenizer, model)
dict_splade = get_topk_tokens(model_output, vocab_dict, topk)

# One "token weight" pair per line.
for token, value in dict_splade.items():
    print(f"{token} {value}")

Output:

に 250
が 248
は 247
の 247
、 244
と 240
を 239
。 239
も 238
で 237
から 221
や 219
な 206
筑波 204
( 204
・ 202
て 197
へ 191
にて 189
など 188
) 186
まで 184
た 182
この 171
- 170
「 170
より 166
その 165
: 163
」 161
Downloads last month
8
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Dataset used to train aken12/splade-japanese