#! -*- coding: utf-8 -*-
# Naive Bayes-based Context Extension (NBCE)
# 使用朴素贝叶斯增加LLM的Context处理长度
# 链接:https://kexue.fm/archives/9617
# Torch 2.0 测试通过

import json
import torch
from transformers import AutoTokenizer
from transformers import AquilaForCausalLM
from transformers import TopPLogitsWarper, LogitsProcessorList
import pdb

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.padding_side = 'left' 
tokenizer.pad_token = tokenizer.unk_token

# 加载Aquila模型
model = AquilaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
device = torch.device('cuda')
model.to(device)
# 加载示例Context
from cyg_conversation import default_conversation

conv = default_conversation.copy()
contexts = json.load(open('code_text_2.json'))

question = "请解释这段程序的功能:"
batch = []
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
batch.append(conv.get_prompt())
# 拼接context和question
for ci,context in enumerate(contexts):        
    conv1 = default_conversation.copy()
    conv1.append_message(conv.roles[0], context+question)
    conv1.append_message(conv.roles[1], None)
    batch.append(conv1.get_prompt())
print('Context长度分布:', [len(text) for text in batch])
print('Context总长度:', sum([len(text) for text in batch]))

# Top-P截断
processors = LogitsProcessorList()
processors.append(TopPLogitsWarper(0.95))

# Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
@torch.inference_mode()
def generate(max_tokens):
    """Naive Bayes-based Context Extension 演示代码
    """
    inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    
    print('input_ids', input_ids.shape)
    past_key_values = None
    n = input_ids.shape[0]
    
    for i in range(max_tokens):
        # 模型输出
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                        use_cache=True,
                        past_key_values=past_key_values
                       )
        past_key_values = outputs.past_key_values
        
        # ===== 核心代码开始 =====
        beta, eta = 0.25, 0.1
        logits = outputs.logits[:, -1]
        logits = logits - logits.logsumexp(dim=-1, keepdims=True)
        logits = processors(input_ids, logits)
        entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
        if i > 0:
            entropy[k] -= eta
        k = entropy[1:].argmin() + 1
        logits_max = logits[k]
        logits_uncond = logits[0]
        logits_merged = (1 + beta) * logits_max - beta * logits_uncond
        logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
        # ===== 核心代码结束 =====
        
        # 构建分布,采样
        # tau = 1是标准的随机采样,tau->0则是贪心搜索
        # 简单起见,这里没有实现topk、topp截断
        tau = 0.01
        probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)
        next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)        
        if next_tokens[0] == tokenizer.eos_token_id:
            break
            
        ret = tokenizer.batch_decode(next_tokens)
        print(ret[0], flush=True, end='')
        
        # prepare for next iteration
        input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
        attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)        


if __name__ == '__main__':
    generate(1000)


"""
========= 输出结果参考 =========

1.菲律宾国家电网公司,中国占股多少?
答:中国国家电网公司持有菲律宾国家电网公司40%的股份。

2.领英计划裁员多少人?
答:领英计划裁员716人。

3.吉利德收购Pharmasset的价格是多少?
答:吉利德收购Pharmasset的价格为110亿美元。

4.丙肝神药Sovaldi在哪一年上市?
答:丙肝神药Sovaldi于2013年上市。

5.中亚峰会将在哪里举行?由谁主持?
答:中亚峰会将在陕西省西安市举行,由国家主席习近平主持。

6.哪个演员由于侮辱人民军队而被立案调查?
答:李昊石因在表演中存在侮辱人民军队的言论而被立案调查。

7.哪个项目宣称“能过坦克”的水上道路?
答:湖北恩施宣称的“能过坦克”水上道路。

8.如果你是默沙东的CEO,你的首要任务是什么?
答:如果我是默沙东的CEO,我的首要任务是如何让基本盘更加坚固,并通过药物联用获得更好的增长。
"""