shibing624 commited on
Commit
9820460
1 Parent(s): 06088e1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -1
README.md CHANGED
@@ -40,12 +40,14 @@ print(i)
40
  import operator
41
  import torch
42
  from transformers import BertTokenizer, BertForMaskedLM
 
43
 
44
  tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese")
45
  model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
 
46
 
47
  texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
48
- outputs = model(**tokenizer(texts, padding=True, return_tensors='pt'))
49
 
50
  def get_errors(corrected_text, origin_text):
51
  details = []
 
40
  import operator
41
  import torch
42
  from transformers import BertTokenizer, BertForMaskedLM
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
 
45
  tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese")
46
  model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
47
+ model = model.to(device)
48
 
49
  texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
50
+ outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
51
 
52
  def get_errors(corrected_text, origin_text):
53
  details = []