We initialize SPLADE-japanese from `tohoku-nlp/bert-base-japanese-v2`. The model is trained on the Japanese subset of the mMARCO dataset.
from transformers import AutoModelForMaskedLM,AutoTokenizer
import torch
import numpy as np
# The SPLADE masked-LM head and its tokenizer are loaded from the same checkpoint id.
_MODEL_ID = "aken12/splade-japanese"
model = AutoModelForMaskedLM.from_pretrained(_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
# Example query: "What kind of research is carried out at the University of Tsukuba?"
query = "筑波大学では何の研究が行われているか?"
def encode_query(query, tokenizer, model):
    """Encode *query* into a SPLADE-style sparse vocabulary-weight vector.

    The masked-LM logits at every token position are passed through
    ``log(1 + ReLU(x))`` and then max-pooled over the non-padding positions.

    Args:
        query: Text to encode; passed unchanged to ``tokenizer``.
        tokenizer: Callable returning PyTorch tensors (``return_tensors="pt"``)
            that include an ``attention_mask`` entry.
        model: Masked-LM model whose output exposes ``.logits``. If it has a
            ``device`` attribute, the inputs are moved there first.

    Returns:
        Tensor of non-negative term weights, pooled over dim 1 of the logits
        (presumably ``(batch, vocab_size)`` for a Hugging Face MLM head).
    """
    encoded_input = tokenizer(query, return_tensors="pt")
    # Move inputs next to the model; otherwise a GPU-resident model raises a
    # device-mismatch error. Models without a ``device`` attribute are used as-is.
    device = getattr(model, "device", None)
    if device is not None:
        encoded_input = {name: tensor.to(device) for name, tensor in encoded_input.items()}
    with torch.no_grad():
        logits = model(**encoded_input, return_dict=True).logits
    # log1p(relu(x)) equals log(1 + relu(x)) but is more accurate near zero.
    weights = torch.log1p(torch.relu(logits))
    # Zero out padding positions. All weights are >= 0, so masked zeros can
    # never exceed a real position in the max-pool below.
    mask = encoded_input["attention_mask"].unsqueeze(-1).to(weights.dtype)
    aggregated_output, _ = torch.max(weights * mask, dim=1)
    return aggregated_output
def get_topk_tokens(reps, vocab_dict, topk):
    """Return the highest-weighted vocabulary tokens of the first row of *reps*.

    Args:
        reps: 2-D tensor of term weights, one row per encoded text. Only
            row 0 is inspected.
        vocab_dict: Mapping from token id (int) to token string.
        topk: Maximum number of tokens to consider. Values larger than the
            vocabulary (``reps.size(-1)``) are clamped instead of raising.

    Returns:
        Dict mapping token string -> weight scaled by 100 and rounded to an int.
        Entries are in descending weight order. Tokens whose scaled weight
        rounds to <= 0 are omitted.
    """
    # torch.topk raises if k exceeds the dimension size; clamp to the vocabulary.
    topk = min(topk, reps.size(-1))
    topk_values, topk_indices = torch.topk(reps, topk, dim=1)
    # .numpy() only works on CPU tensors that do not require grad.
    values = np.rint(topk_values.detach().cpu().numpy() * 100).astype(int)
    return {
        vocab_dict[token_id]: int(value)
        for token_id, value in zip(topk_indices[0].tolist(), values[0])
        if value > 0
    }
# Invert the tokenizer vocabulary (token -> id) so ids can be shown as text.
vocab_dict = {token_id: token for token, token_id in tokenizer.get_vocab().items()}
# Inspect roughly one token per thousand vocabulary entries.
topk = len(vocab_dict) // 1000
model_output = encode_query(query, tokenizer, model)
dict_splade = get_topk_tokens(model_output, vocab_dict, topk)
for token, value in dict_splade.items():
    print(f"{token} {value}")
Output (token and its weight, scaled by 100 and rounded):
に 250
が 248
は 247
の 247
、 244
と 240
を 239
。 239
も 238
で 237
から 221
や 219
な 206
筑波 204
( 204
・ 202
て 197
へ 191
にて 189
など 188
) 186
まで 184
た 182
この 171
- 170
「 170
より 166
その 165
: 163
」 161
- Downloads last month
- 8
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.