import os
import time
import warnings

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


class Inference:

    def __init__(self, silent=False) -> None:
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        # Load the fine-tuned weights saved in the SaveState directory next to this file.
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        self.model.eval()
        if not silent:
            print(f"Model loading took {time.perf_counter() - start_time:.2f} seconds")

    def local_file_path(self, path):
        # Resolve a path relative to this script's directory.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def generate(self, prompt, max_length=2000, temperature=0.5, do_sample=True, stop_token=None, callback=None, silent=True):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            start_time = time.perf_counter()
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
            generated_text = input_ids
            # Generate in chunks of up to 50 tokens so the callback can stream
            # output as it is produced.
            while generated_text.shape[1] < max_length:
                length = min(50, max_length - generated_text.shape[1])
                # Truncate the running context so prompt plus new tokens fit in
                # GPT-2's 1024-token window.
                input_ids = generated_text[:, -(1024 - length):]
                with torch.no_grad():
                    # max_new_tokens counts only freshly generated tokens,
                    # independent of the prompt length.
                    outputs = self.model.generate(input_ids, max_new_tokens=length, temperature=temperature, do_sample=do_sample, pad_token_id=self.tokenizer.eos_token_id)
                # Sampling may stop early at EOS, so take everything past the
                # input rather than a fixed-size tail.
                new_tokens = outputs[0][input_ids.shape[1]:]
                if callback is not None:
                    for token in new_tokens:
                        callback(self.tokenizer.decode([token]))
                generated_text = torch.cat((generated_text, new_tokens.unsqueeze(0)), dim=-1)
                if stop_token is not None and stop_token in self.tokenizer.decode(generated_text[0]):
                    break
            if not silent:
                print(f"Generation took {time.perf_counter() - start_time:.2f} seconds")
            return self.tokenizer.decode(generated_text[0], skip_special_tokens=True)


inference = Inference()


def spec(stre):
    # Streaming callback: print each decoded token immediately, with no newline.
    print(stre, end="", flush=True)
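
# A hypothetical usage sketch (prompt and stop_token are illustrative): pass spec
# as the callback to stream tokens while generating, and use stop_token to cut
# the reply at the first newline.
#     inference.generate("Q: Who wrote Hamlet?\nA:", max_length=120,
#                        callback=spec, stop_token="\n", silent=True)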


if __name__ == "__main__":
    while True:
        print(inference.generate(input(">>> "), max_length=100, temperature=0.8, silent=True))