# encoding=utf-8
"""Prompted generation helpers for the TransfoXL reasoning models:
deduction ("{text}, therefore ...") and abduction ("the reason {text} is because ...").
"""
from typing import List, Union

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import T5Tokenizer

from fengshen.models.transfo_xl_reasoning import TransfoXLModel
from fengshen.utils import sample_sequence_batch

# Token id used both for padding and as the end-of-text marker by these
# checkpoints' vocabulary.
END_TOKEN_ID = 50000


def en_to_zh(sentence: str) -> str:
    """Map ASCII punctuation to its full-width Chinese counterpart.

    The straight quote characters each appear twice in ``en_pun``, so the
    dict comprehension keeps only the last mapping: ``"`` always becomes
    ``”`` and ``'`` always becomes ``’``.
    """
    en_pun = u",.!?[]()<>\"\"''"
    zh_pun = u",。!?【】()《》“”‘’"
    table = {ord(f): ord(t) for f, t in zip(en_pun, zh_pun)}
    return sentence.translate(table)


def deduction_generate(
    model: TransfoXLModel,
    tokenizer: T5Tokenizer,
    input_text: Union[str, List[str]],
    device: int = 0,
    batch_size: int = 2,
    temperature: float = 1.0,
    repetition_penalty: float = 2.0,
    max_out_seq: int = 512,
    top_p: float = 0.6,
) -> List[str]:
    """Generate with the fixed deduction prompt ``{text},因而`` ("{text}, therefore ...")."""

    model = model.eval().cuda(device)
    
    if isinstance(input_text, str):
        input_text = [input_text]

    input_text = [f"<bos>{text},因而" for text in input_text]

    input_ids = [torch.tensor(ids[:-1]) for ids in tokenizer(input_text).input_ids]
    input_length = [len(ids) for ids in input_ids]

    output = []

    for index in range(0, len(input_ids), batch_size):
        input_ids_batch = pad_sequence(
            input_ids[index: index + batch_size],
            batch_first=True,
            padding_value=END_TOKEN_ID,
        )
        input_ids_length = torch.tensor(input_length[index: index + batch_size])

        res_ids_batch, _ = sample_sequence_batch(
            model=model,
            context_tokens_tensor=input_ids_batch.cuda(device=device),
            context_length_tensor=input_ids_length.cuda(device=device),
            end_token_id=END_TOKEN_ID,
            top_k=0, top_p=top_p,
            max_out_seq=max_out_seq,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        # Keep only the newly generated tokens, restore full-width punctuation,
        # and strip the spaces the tokenizer inserts between pieces.
        res_sentence = [
            en_to_zh(tokenizer.decode(ids[length:])).replace(" ", "")
            for ids, length in zip(res_ids_batch, input_length[index: index + batch_size])
        ]

        output.extend(res_sentence)

    return output


def abduction_generate(
    model: TransfoXLModel,
    tokenizer: T5Tokenizer,
    input_text: Union[str, List[str]],
    device: int = 0,
    batch_size: int = 2,
    temperature: float = 1.0,
    repetition_penalty: float = 2.0,
    max_out_seq: int = 512,
    top_p: float = 0.6,
) -> List[str]:
    """Generate with the fixed abduction prompt ``之所以{text},是因为``
    ("the reason that {text} is because ...").
    """

    model = model.eval().cuda(device)

    if isinstance(input_text, str):
        input_text = [input_text]

    input_text = [f"<bos>之所以{text},是因为" for text in input_text]

    input_ids = [torch.tensor(ids[:-1]) for ids in tokenizer(input_text).input_ids]
    input_length = [len(ids) for ids in input_ids]

    output = []

    for index in range(0, len(input_ids), batch_size):
        input_ids_batch = pad_sequence(
            input_ids[index: index + batch_size],
            batch_first=True,
            padding_value=END_TOKEN_ID,
        )
        input_ids_length = torch.tensor(input_length[index: index + batch_size])

        res_ids_batch, _ = sample_sequence_batch(
            model=model,
            context_tokens_tensor=input_ids_batch.cuda(device=device),
            context_length_tensor=input_ids_length.cuda(device=device),
            end_token_id=END_TOKEN_ID,
            top_k=0, top_p=top_p,
            max_out_seq=max_out_seq,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        res_sentence = [
            en_to_zh(tokenizer.decode(ids[length:])).replace(" ", "")
            for ids, length in zip(res_ids_batch, input_length[index: index + batch_size])
        ]

        output.extend(res_sentence)

    return output
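

# Usage sketch (not part of the library): a minimal example of calling the
# helpers above. The checkpoint name and tokenizer kwargs are assumptions
# modeled on the released IDEA-CCNL reasoning checkpoints; substitute the
# deduction/abduction checkpoint you actually use. Requires a CUDA device,
# since both helpers move the model onto the GPU.
if __name__ == "__main__":
    pretrained = "IDEA-CCNL/Randeng-TransformerXL-5B-Deduction-Chinese"  # assumed checkpoint name

    tokenizer = T5Tokenizer.from_pretrained(
        pretrained,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        extra_ids=0,
    )
    model = TransfoXLModel.from_pretrained(pretrained)

    # Prompt sent to the model: "<bos>下了一夜的大雨,因而" ("it rained hard all
    # night, therefore ..."); the model completes the consequence.
    print(deduction_generate(model, tokenizer, "下了一夜的大雨"))

    # abduction_generate is called the same way but should be paired with an
    # abduction checkpoint, since it prompts for a cause instead of a result:
    # print(abduction_generate(model, tokenizer, "他今天迟到了"))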