# HumorGPT / Inference.py
# Uploaded by TheAutonomous ("Upload 4 files", commit 0f9b91a).
import os, time, torch, warnings
from transformers import GPT2LMHeadModel, GPT2Tokenizer
class Inference():
    """Wrap a GPT-2 language model loaded from the local ``SaveState`` directory."""

    def __init__(self, silent=False) -> None:
        """Load the ``distilgpt2`` tokenizer and the weights stored in ``SaveState``.

        Args:
            silent: When False, print how long loading took.
        """
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        # Inference only: put the model in eval mode (disables dropout).
        self.model.eval()
        if not silent:
            print(f"Model Loading Took {time.perf_counter()-start_time} Seconds")

    def local_file_path(self, path):
        """Return ``path`` resolved against the directory containing this file."""
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def generate(self, prompt, max_length=2000, temperature=0.5, do_sample=True, stop_token=None, callback=None, silent=True):
        """Generate a continuation of ``prompt``.

        Text is produced in chunks of at most 50 new tokens. Generation stops
        when any of these happens:

        * the total sequence (prompt included) reaches ``max_length`` tokens;
        * the model emits its end-of-text token, or returns no new tokens;
        * ``stop_token`` appears in the decoded text.

        Args:
            prompt: Text to continue.
            max_length: Maximum total length in tokens, prompt included.
            temperature: Sampling temperature, passed to ``model.generate``.
            do_sample: Passed to ``model.generate``.
            stop_token: Optional substring. It is searched for anywhere in
                the decoded text, prompt included. Generation stops after the
                chunk in which it is found; the text is not truncated there.
            callback: Optional callable. It is invoked in order with the
                decoded text of each newly generated token.
            silent: When False, print how long generation took.

        Returns:
            The prompt plus generated text, with special tokens removed.
        """
        start_time = time.perf_counter()
        eos_id = self.tokenizer.eos_token_id
        # The model cannot attend to more positions than it was built with.
        # Condition only on the most recent tokens so long outputs don't crash.
        context_size = getattr(self.model.config, "n_positions", 1024)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            generated = self.tokenizer.encode(prompt, return_tensors='pt')
            if generated.shape[1] == 0 and eos_id is not None:
                # An empty input cannot be fed to the model; seed it with the
                # end-of-text token, which is stripped again when decoding.
                generated = torch.tensor([[eos_id]], dtype=torch.long)
            while generated.shape[1] < max_length:
                chunk = min(50, max_length - generated.shape[1])
                # Sliding window: keep room for `chunk` new positions.
                context = generated[:, -max(1, context_size - chunk):]
                with torch.no_grad():
                    outputs = self.model.generate(
                        context,
                        attention_mask=torch.ones_like(context),
                        max_new_tokens=chunk,
                        temperature=temperature,
                        do_sample=do_sample,
                        pad_token_id=eos_id,
                    )
                # The output echoes the context first; keep only the new tokens.
                new_tokens = outputs[0, context.shape[1]:]
                if new_tokens.numel() == 0:
                    break  # nothing produced; avoid looping forever
                if callback is not None:
                    for token in new_tokens:
                        callback(self.tokenizer.decode([token]))
                generated = torch.cat((generated, new_tokens.unsqueeze(0)), dim=-1)
                if eos_id is not None and bool((new_tokens == eos_id).any()):
                    break  # the model signalled end of text
                if stop_token is not None and stop_token in self.tokenizer.decode(generated[0]):
                    break
        if not silent:
            print(f"Generation Took {time.perf_counter()-start_time} Seconds")
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)
# Module-level singleton. Instantiating at import time loads the tokenizer and
# the local "SaveState" weights. This rebinds the name `Inference` from the class
# to the instance, so code using this module calls `Inference.generate(...)`.
Inference = Inference()
def spec(stre):
    """Print ``stre`` with no trailing newline and flush stdout immediately.

    The signature fits ``Inference.generate``'s ``callback`` (one decoded
    token string per call), so it can stream text to the console. Since no
    newline is written, the fragment may otherwise stay in stdout's buffer
    instead of appearing right away; hence the explicit flush.
    """
    print(stre, end="", flush=True)
if __name__ == "__main__":
    # Simple REPL: read a prompt, print the generated continuation, repeat.
    # Ctrl-D (EOF) or Ctrl-C ends the session cleanly instead of a traceback.
    try:
        while True:
            prompt = input(">>> ")
            print(Inference.generate(prompt, max_length=100, temperature=0.8, silent=True))
    except (EOFError, KeyboardInterrupt):
        print()