BAAI
/

shunxing1234 commited on
Commit
15e92fc
·
1 Parent(s): d910332

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +12 -12
README.md CHANGED
@@ -80,16 +80,16 @@ from transformers import AutoModelForCausalLM
80
  from transformers import TopPLogitsWarper, LogitsProcessorList
81
  import pdb
82
 
83
- # 加载tokenizer
84
  tokenizer = AutoTokenizer.from_pretrained(model_path)
85
  tokenizer.padding_side = 'left'
86
  tokenizer.pad_token = tokenizer.unk_token
87
 
88
- # 加载Aquila模型
89
  model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
90
  device = torch.device('cuda')
91
  model.to(device)
92
- # 加载示例Context
93
  from cyg_conversation import default_conversation
94
 
95
  conv = default_conversation.copy()
@@ -100,7 +100,7 @@ batch = []
100
  conv.append_message(conv.roles[0], question)
101
  conv.append_message(conv.roles[1], None)
102
  batch.append(conv.get_prompt())
103
- # 拼接contextquestion
104
  for ci,context in enumerate(contexts):
105
  conv1 = default_conversation.copy()
106
  conv1.append_message(conv.roles[0], context+question)
@@ -109,14 +109,14 @@ for ci,context in enumerate(contexts):
109
  print('Context长度分布:', [len(text) for text in batch])
110
  print('Context总长度:', sum([len(text) for text in batch]))
111
 
112
- # Top-P截断
113
  processors = LogitsProcessorList()
114
  processors.append(TopPLogitsWarper(0.95))
115
 
116
  # Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
117
  @torch.inference_mode()
118
  def generate(max_tokens):
119
- """Naive Bayes-based Context Extension 演示代码
120
  """
121
  inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
122
  input_ids = inputs.input_ids
@@ -127,7 +127,7 @@ def generate(max_tokens):
127
  n = input_ids.shape[0]
128
 
129
  for i in range(max_tokens):
130
- # 模型输出
131
  outputs = model(input_ids=input_ids,
132
  attention_mask=attention_mask,
133
  return_dict=True,
@@ -136,7 +136,7 @@ def generate(max_tokens):
136
  )
137
  past_key_values = outputs.past_key_values
138
 
139
- # ===== 核心代码开始 =====
140
  beta, eta = 0.25, 0.1
141
  logits = outputs.logits[:, -1]
142
  logits = logits - logits.logsumexp(dim=-1, keepdims=True)
@@ -149,11 +149,11 @@ def generate(max_tokens):
149
  logits_uncond = logits[0]
150
  logits_merged = (1 + beta) * logits_max - beta * logits_uncond
151
  logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
152
- # ===== 核心代码结束 =====
153
 
154
- # 构建分布,采样
155
- # tau = 1是标准的随机采样,tau->0则是贪心搜索
156
- # 简单起见,这里没有实现topk、topp截断
157
  tau = 0.01
158
  probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)
159
  next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
 
80
  from transformers import TopPLogitsWarper, LogitsProcessorList
81
  import pdb
82
 
83
+ # load tokenizer
84
  tokenizer = AutoTokenizer.from_pretrained(model_path)
85
  tokenizer.padding_side = 'left'
86
  tokenizer.pad_token = tokenizer.unk_token
87
 
88
+ # load Aquila model
89
  model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
90
  device = torch.device('cuda')
91
  model.to(device)
92
+ # load example Context
93
  from cyg_conversation import default_conversation
94
 
95
  conv = default_conversation.copy()
 
100
  conv.append_message(conv.roles[0], question)
101
  conv.append_message(conv.roles[1], None)
102
  batch.append(conv.get_prompt())
103
+ # concat context and question
104
  for ci,context in enumerate(contexts):
105
  conv1 = default_conversation.copy()
106
  conv1.append_message(conv.roles[0], context+question)
 
109
  print('Context长度分布:', [len(text) for text in batch])
110
  print('Context总长度:', sum([len(text) for text in batch]))
111
 
112
+ # Top-P
113
  processors = LogitsProcessorList()
114
  processors.append(TopPLogitsWarper(0.95))
115
 
116
  # Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
117
  @torch.inference_mode()
118
  def generate(max_tokens):
119
+ """Naive Bayes-based Context Extension example code
120
  """
121
  inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
122
  input_ids = inputs.input_ids
 
127
  n = input_ids.shape[0]
128
 
129
  for i in range(max_tokens):
130
+ # model output
131
  outputs = model(input_ids=input_ids,
132
  attention_mask=attention_mask,
133
  return_dict=True,
 
136
  )
137
  past_key_values = outputs.past_key_values
138
 
139
+ # ===== NBCE core code starts =====
140
  beta, eta = 0.25, 0.1
141
  logits = outputs.logits[:, -1]
142
  logits = logits - logits.logsumexp(dim=-1, keepdims=True)
 
149
  logits_uncond = logits[0]
150
  logits_merged = (1 + beta) * logits_max - beta * logits_uncond
151
  logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
152
+ # ===== NBCE core code ends =====
153
 
154
+ # Building a distribution and sampling
155
+ # tau = 1 is standard random sampling,tau->0 is greedy search
156
+ # For simplicity, top-k and top-p truncation are not implemented here.
157
  tau = 0.01
158
  probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)
159
  next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)