supermy committed on
Commit
aac4628
·
1 Parent(s): 4be04dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -6,6 +6,8 @@ import torch.nn.functional as F
6
 
7
  from transformers import AutoTokenizer, GPT2LMHeadModel
8
  tokenizer = AutoTokenizer.from_pretrained("supermy/jinyong-gpt2")
 
 
9
  model = GPT2LMHeadModel.from_pretrained("supermy/jinyong-gpt2")
10
  model.eval()
11
 
@@ -27,12 +29,12 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=0.0, filter_value=-float('Inf'
27
 
28
  def generate(title, context, max_len):
29
 
30
- input_ids=tokenizer.encode(title + "-" + context, add_special_tokens=False)
31
 
32
- # title_ids = tokenizer.encode(title, add_special_tokens=False)
33
- # context_ids = tokenizer.encode(context, add_special_tokens=False)
34
- # input_ids = title_ids + [sep_id] + context_ids
35
- # print(input_ids)
36
 
37
  cur_len = len(input_ids)
38
  input_len = cur_len
 
6
 
7
  from transformers import AutoTokenizer, GPT2LMHeadModel
8
  tokenizer = AutoTokenizer.from_pretrained("supermy/jinyong-gpt2")
9
+ tokenizer.add_special_tokens(['SEP'])
10
+ tokenizer.add_special_tokens(['UNK'])
11
  model = GPT2LMHeadModel.from_pretrained("supermy/jinyong-gpt2")
12
  model.eval()
13
 
 
29
 
30
  def generate(title, context, max_len):
31
 
32
+ # input_ids=tokenizer.encode(title + "-" + context, add_special_tokens=False)
33
 
34
+ title_ids = tokenizer.encode(title, add_special_tokens=False)
35
+ context_ids = tokenizer.encode(context, add_special_tokens=False)
36
+ input_ids = title_ids + [sep_id] + context_ids
37
+ print(input_ids)
38
 
39
  cur_len = len(input_ids)
40
  input_len = cur_len