Katsumata420 committed on
Commit
4dcdff9
·
verified ·
1 Parent(s): de9ad40

Update modeling_retrieva_bert.py

Browse files
Files changed (1) hide show
  1. modeling_retrieva_bert.py +7 -7
modeling_retrieva_bert.py CHANGED
@@ -65,7 +65,7 @@ from .configuration_retrieva_bert import RetrievaBertConfig
65
  logger = logging.get_logger(__name__)
66
 
67
  _CONFIG_FOR_DOC = "RetrievaBertConfig"
68
- _CHECKPOINT_FOR_DOC = "nvidia/megatron-bert-cased-345m"
69
 
70
 
71
  def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
@@ -1170,8 +1170,8 @@ class RetrievaBertForPreTraining(RetrievaBertPreTrainedModel):
1170
  >>> from models import RetrievaBertForPreTraining
1171
  >>> import torch
1172
 
1173
- >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
1174
- >>> model = RetrievaBertForPreTraining.from_pretrained("nvidia/megatron-bert-cased-345m")
1175
 
1176
  >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1177
  >>> outputs = model(**inputs)
@@ -1294,8 +1294,8 @@ class RetrievaBertForCausalLM(RetrievaBertPreTrainedModel):
1294
  >>> from models import RetrievaBertForCausalLM, RetrievaBertConfig
1295
  >>> import torch
1296
 
1297
- >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
1298
- >>> model = RetrievaBertForCausalLM.from_pretrained("nvidia/megatron-bert-cased-345m", is_decoder=True)
1299
 
1300
  >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1301
  >>> outputs = model(**inputs)
@@ -1528,8 +1528,8 @@ class RetrievaBertForNextSentencePrediction(RetrievaBertPreTrainedModel):
1528
  >>> from models import RetrievaBertForNextSentencePrediction
1529
  >>> import torch
1530
 
1531
- >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
1532
- >>> model = RetrievaBertForNextSentencePrediction.from_pretrained("nvidia/megatron-bert-cased-345m")
1533
 
1534
  >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1535
  >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
 
65
  logger = logging.get_logger(__name__)
66
 
67
  _CONFIG_FOR_DOC = "RetrievaBertConfig"
68
+ _CHECKPOINT_FOR_DOC = "retrieva-jp/bert-1.3b"
69
 
70
 
71
  def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
 
1170
  >>> from models import RetrievaBertForPreTraining
1171
  >>> import torch
1172
 
1173
+ >>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
1174
+ >>> model = RetrievaBertForPreTraining.from_pretrained("retrieva-jp/bert-1.3b")
1175
 
1176
  >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1177
  >>> outputs = model(**inputs)
 
1294
  >>> from models import RetrievaBertForCausalLM, RetrievaBertConfig
1295
  >>> import torch
1296
 
1297
+ >>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
1298
+ >>> model = RetrievaBertForCausalLM.from_pretrained("retrieva-jp/bert-1.3b", is_decoder=True)
1299
 
1300
  >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1301
  >>> outputs = model(**inputs)
 
1528
  >>> from models import RetrievaBertForNextSentencePrediction
1529
  >>> import torch
1530
 
1531
+ >>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
1532
+ >>> model = RetrievaBertForNextSentencePrediction.from_pretrained("retrieva-jp/bert-1.3b")
1533
 
1534
  >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1535
  >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."