shunxing1234 committed on
Commit 6980f4b · 1 Parent(s): 1357761

Update README.md

Files changed (1)
  1. README.md +0 -104
README.md CHANGED
@@ -70,110 +70,6 @@ with torch.no_grad():
  print(out)
  ```
 
- Inference using [NBCE](https://github.com/bojone/NBCE/tree/main)
-
- ```python
- import jsonlines
- import torch
- from transformers import AutoTokenizer
- from transformers import AquilaForCausalLM
- from transformers import TopPLogitsWarper, LogitsProcessorList
-
- from cyg_conversation import default_conversation
-
- def preprocess(text, question="回答:"):
-     # One context-free prompt (the question alone), then one prompt
-     # per 1024-character window of the long input text.
-     contexts = []
-     conv = default_conversation.copy()
-     conv.append_message(conv.roles[0], question)
-     conv.append_message(conv.roles[1], None)
-     contexts.append(conv.get_prompt())
-     for pos in range(0, len(text), 1024):
-         conv1 = default_conversation.copy()
-         conv1.append_message(conv1.roles[0], text[pos:pos + 1024] + question)
-         conv1.append_message(conv1.roles[1], None)
-         contexts.append(conv1.get_prompt())
-     print('Context length distribution:', [len(c) for c in contexts])
-     print('Total context length:', sum(len(c) for c in contexts))
-     return contexts
-
- # load tokenizer
- model_path = "checkpoints/hf_weight"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- tokenizer.padding_side = 'left'
- tokenizer.pad_token = tokenizer.unk_token
-
- # load Aquila model
- model = AquilaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
- device = torch.device('cuda')
- model.to(device)
-
- # Top-p truncation applied to each context's distribution
- processors = LogitsProcessorList()
- processors.append(TopPLogitsWarper(0.95))
-
- # Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
- @torch.inference_mode()
- def generate(max_tokens, batch):
-     """Naive Bayes-based Context Extension"""
-     inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
-     input_ids = inputs.input_ids
-     attention_mask = inputs.attention_mask
-
-     past_key_values = None
-     n = input_ids.shape[0]
-
-     for i in range(max_tokens):
-         # model output
-         outputs = model(input_ids=input_ids,
-                         attention_mask=attention_mask,
-                         return_dict=True,
-                         use_cache=True,
-                         past_key_values=past_key_values)
-         past_key_values = outputs.past_key_values
-
-         # ===== NBCE core code starts =====
-         beta, eta = 0.25, 0.1
-         logits = outputs.logits[:, -1]
-         logits = logits - logits.logsumexp(dim=-1, keepdim=True)
-         logits = processors(input_ids, logits)
-         entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
-         if i > 0:
-             entropy[k] -= eta
-         k = entropy[1:].argmin() + 1
-         logits_max = logits[k]
-         logits_uncond = logits[0]
-         logits_merged = (1 + beta) * logits_max - beta * logits_uncond
-         logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
-         # ===== NBCE core code ends =====
-
-         # Build a distribution and sample from it.
-         # tau = 1 is standard random sampling, tau -> 0 approaches greedy search.
-         # For simplicity, top-k and top-p truncation are not implemented here.
-         tau = 0.01
-         probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
-         next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
-         if next_tokens[0] == tokenizer.eos_token_id:
-             break
-
-         ret = tokenizer.batch_decode(next_tokens)
-         print(ret[0], flush=True, end='')
-
-         # prepare for the next iteration: every context is fed the same sampled token
-         input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
-         attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)
-
-
- if __name__ == '__main__':
-     with open("gaokao_chinese_dataset.jsonl", 'r') as f:
-         for item in jsonlines.Reader(f):
-             batch = preprocess(item['prompt'], question=item['question'])
-             generate(10, batch)
- ```
 
  ## License
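
For orientation, the `preprocess` function in the removed block builds NBCE's batch: one context-free prompt that carries only the question, plus one prompt per 1024-character window of the long input. A minimal sketch of that batching idea, with a plain string template standing in for the `cyg_conversation` chat format (the `make_contexts` name and the template are illustrative, not Aquila's real prompt format):

```python
def make_contexts(text: str, question: str, window: int = 1024) -> list[str]:
    # Row 0: the question alone -- NBCE's context-free ("unconditional") prompt.
    contexts = [question]
    # One prompt per fixed-size character window of the long document;
    # Python slicing clamps the final window automatically.
    for pos in range(0, len(text), window):
        contexts.append(text[pos:pos + window] + question)
    return contexts

batch = make_contexts("a very long document " * 300, "回答:")
print(len(batch), [len(c) for c in batch])
```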
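
The heart of the generation loop is the marked NBCE block: every context predicts the next token independently, the most confident context (lowest entropy, with a small hysteresis bonus eta for the previous winner) is selected, and its log-probabilities are extrapolated away from the context-free row. A self-contained sketch of one merging step, with random logits standing in for model output (the shapes, the `-100` truncation threshold, and the beta = 0.25 default follow the removed code; the toy tensors are illustrative):

```python
import torch

torch.manual_seed(0)
n_contexts, vocab = 4, 32                # row 0 is the context-free prompt
logits = torch.randn(n_contexts, vocab)  # stand-in for outputs.logits[:, -1]
beta = 0.25                              # extrapolation strength

# Normalize each row to log-probabilities.
logits = logits - logits.logsumexp(dim=-1, keepdim=True)

# Pick the most confident (lowest-entropy) context among rows 1..n-1.
entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
k = entropy[1:].argmin() + 1

# Extrapolate the winner away from the unconditional distribution:
# (1 + beta) * log p(x | context_k) - beta * log p(x | no context),
# keeping the winner's value where the unconditional row is truncated.
merged = torch.where(
    logits[0] > -100,
    (1 + beta) * logits[k] - beta * logits[0],
    logits[k],
)
print('winning context:', k.item(), 'greedy token id:', merged.argmax().item())
```

With beta = 0 this reduces to plain max-confidence voting across contexts; a larger beta pushes the output further toward what the chosen context adds over the bare question.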