File size: 2,698 Bytes
577164e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import argparse
import os
import random
import os
import json
import random, os
import numpy as np
import torch

from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import TextStreamer, GenerationConfig


class LocalStoppingCriteria(StoppingCriteria):

    def __init__(self, tokenizer, stop_words=[]):
        super().__init__()

        stops = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for
                 stop_word in stop_words]
        print('stop_words', stop_words)
        print('stop_words_ids', stops)
        self.stop_words = stop_words
        self.stops = [stop.cuda() for stop in stops]
        self.tokenizer = tokenizer

    def _compare_token(self, input_ids):
        for stop in self.stops:
            if len(stop.size()) != 1:
                continue
            stop_len = len(stop)
            if torch.all((stop == input_ids[0][-stop_len:])).item():
                return True

        return False

    def _compare_decode(self, input_ids):
        input_str = self.tokenizer.decode(input_ids[0])
        for stop_word in self.stop_words:
            if input_str.endswith(stop_word):
                return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        return self._compare_decode(input_ids)


def seed_everything(seed: int):

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def generation(model, tokenizer, x, max_new_tokens=1024):
    
    stopping_criteria = StoppingCriteriaList(
    [LocalStoppingCriteria(tokenizer=tokenizer, stop_words=[tokenizer.eos_token])])
    streamer = TextStreamer(tokenizer)


    generation_config = GenerationConfig(
        temperature=1.0,
        top_p=0.8,
        top_k=100,
        max_new_tokens=max_new_tokens,
        early_stopping=True,
        do_sample=True,
    )
    gened = model.generate(
        **tokenizer(
            x,
            return_tensors='pt',
            return_token_type_ids=False
        ).to('cuda'),
        generation_config=generation_config,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
        streamer=streamer,
    )
    response = tokenizer.decode(gened[0])
    only_gen_text = response.split(x)
    if len(only_gen_text) == 2:
        response = only_gen_text[-1]
    response = response.replace(tokenizer.eos_token, '')
    return response