# bevelapi/models/gpt2.py
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
import tensorflow as tf

model_name = "gpt2"

def load():
    # Load the GPT-2 model and tokenizer into module-level globals so
    # generate() can reuse them across calls.
    global model
    global tokenizer
    model = TFGPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

def generate(input_text):
    # Tokenize as TensorFlow tensors; the original "pt" (PyTorch) tensors
    # would not work with the TF model.
    input_ids = tokenizer.encode(input_text, return_tensors="tf", truncation=True)
    attention_mask = tf.ones_like(input_ids)
    # Generate with beam search, passing the attention mask (previously
    # computed but unused) and stopping at the EOS token; pad_token_id is
    # set to EOS because GPT-2 has no dedicated pad token.
    output_ids = model.generate(input_ids, attention_mask=attention_mask,
                                num_beams=5, no_repeat_ngram_size=2,
                                max_new_tokens=100,
                                eos_token_id=tokenizer.eos_token_id,
                                pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
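
# Minimal usage sketch (an assumption, not part of the original file): call
# load() once at startup, then generate() per request. The prompt string
# below is purely illustrative.
if __name__ == "__main__":
    load()
    print(generate("Hello, my name is"))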