squarelike commited on
Commit
ac5886f
1 Parent(s): 058dd06

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -6
README.md CHANGED
@@ -54,7 +54,7 @@ class StoppingCriteriaSub(StoppingCriteria):
54
 
55
  return False
56
 
57
- stop_words_ids = torch.tensor([[829, 45107, 29958], [1533, 45107, 29958], [829, 45107, 29958], [21106, 45107, 29958]])
58
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
59
 
60
  def gen(lan="en", x=""):
@@ -67,13 +67,10 @@ def gen(lan="en", x=""):
67
  prompt,
68
  return_tensors='pt',
69
  return_token_type_ids=False
70
- ),
71
- max_new_tokens=1000,
72
  temperature=0.1,
73
- no_repeat_ngram_size=10,
74
- early_stopping=True,
75
  do_sample=True,
76
- eos_token_id=2,
77
  stopping_criteria=stopping_criteria
78
  )
79
  return tokenizer.decode(gened[0][1:]).replace(prompt+" ", "").replace("</끝>", "")
 
54
 
55
  return False
56
 
57
+ stop_words_ids = torch.tensor([[829, 45107, 29958], [1533, 45107, 29958], [829, 45107, 29958], [21106, 45107, 29958]]).to("cuda")
58
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
59
 
60
  def gen(lan="en", x=""):
 
67
  prompt,
68
  return_tensors='pt',
69
  return_token_type_ids=False
70
+ ).to("cuda"),
71
+ max_new_tokens=2000,
72
  temperature=0.1,
 
 
73
  do_sample=True,
 
74
  stopping_criteria=stopping_criteria
75
  )
76
  return tokenizer.decode(gened[0][1:]).replace(prompt+" ", "").replace("</끝>", "")