import os, time, torch, warnings
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class Inference:
    """Thin wrapper for loading a locally fine-tuned distilgpt2 checkpoint and generating text from it."""

    def __init__(self, silent=False) -> None:
        start_time = time.perf_counter()
        # Tokenizer comes from the hub; the fine-tuned weights are loaded from the local "SaveState" directory.
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        self.model.eval()
        if not silent:
            print(f"Model loading took {time.perf_counter() - start_time:.2f} seconds")

    def local_file_path(self, path):
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def generate(self, prompt, max_length=2000, temperature=0.5, do_sample=True, stop_token=None, callback=None, silent=True):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            start_time = time.perf_counter()
            input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
            generated_text = input_ids
            while generated_text.shape[1] < max_length:
                # Generate in chunks of at most 50 new tokens so the callback can stream partial output.
                length = min(50, max_length - generated_text.shape[1])
                with torch.no_grad():
                    outputs = self.model.generate(input_ids, max_new_tokens=length, temperature=temperature, do_sample=do_sample, pad_token_id=self.tokenizer.eos_token_id)
                new_tokens = outputs[0][input_ids.shape[1]:]
                if new_tokens.shape[0] == 0:
                    break
                if callback is not None:
                    for token in new_tokens.tolist():
                        callback(self.tokenizer.decode([token]))
                generated_text = torch.cat((generated_text, new_tokens.unsqueeze(0)), dim=-1)
                # Only the newest chunk is fed back as context; this keeps each step short,
                # but earlier text is not conditioned on.
                input_ids = new_tokens.unsqueeze(0)
                if stop_token is not None and stop_token in self.tokenizer.decode(generated_text[0]):
                    break
            if not silent:
                print(f"Generation took {time.perf_counter() - start_time:.2f} seconds")
            return self.tokenizer.decode(generated_text[0], skip_special_tokens=True)

inference = Inference()

def spec(text):
    # Example callback: print each decoded token as it arrives, without a trailing newline.
    print(text, end="")
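# A minimal streaming sketch (not executed by default), assuming the `inference` instance above:
# pass `spec` as the callback so each decoded token is printed as soon as it is generated, e.g.
#   inference.generate("Once upon a time", max_length=100, callback=spec, silent=True)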

if __name__ == "__main__":
    while True:
        print(inference.generate(input(">>> "), max_length=100, temperature=0.8, silent=True))