File size: 753 Bytes
1cb9a8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def load_model(model_name='gpt2-medium'):
    """Fetch a pretrained GPT-2 language model together with its tokenizer.

    Args:
        model_name: Hub identifier or local path handed to ``from_pretrained``
            for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    lm = GPT2LMHeadModel.from_pretrained(model_name)
    tok = GPT2Tokenizer.from_pretrained(model_name)
    return lm, tok
def generate_text(input_str, model, tokenizer, length=50, temperature=1.0,
                  do_sample=None):
    """Continue ``input_str`` using a pretrained GPT-2 model.

    Args:
        input_str: Prompt text to condition generation on.
        model: Causal language model exposing the Hugging Face ``generate``
            API (e.g. the ``GPT2LMHeadModel`` from ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        length: Maximum total length, in tokens, of the returned sequence.
            This is forwarded as ``max_length``, so it *includes* the prompt
            tokens.
        temperature: Softmax temperature; only has an effect when sampling.
        do_sample: Sample instead of decoding greedily. ``None`` (default)
            samples only when ``temperature`` differs from 1.0, so the default
            call stays greedy while an explicit temperature is honoured.

    Returns:
        The decoded text of the first generated sequence, prompt included.
    """
    if do_sample is None:
        # Greedy search ignores ``temperature``; enable sampling whenever a
        # non-default temperature is requested so the argument takes effect.
        do_sample = temperature != 1.0
    # Calling the tokenizer (instead of ``encode``) also returns the
    # attention mask; move everything to the model's device.
    encoded = tokenizer(input_str, return_tensors='pt').to(model.device)
    out = model.generate(
        **encoded,
        max_length=length,
        do_sample=do_sample,
        temperature=temperature,
        # GPT-2 has no pad token; use EOS explicitly (what generate() would
        # otherwise fall back to, with a warning).
        pad_token_id=tokenizer.eos_token_id,
    )
    # ``out`` is a (batch, seq) tensor; decode the single sequence we asked for.
    return tokenizer.decode(out[0])
# Demo: load the medium GPT-2 checkpoint and continue a fairy-tale opening.
model, tokenizer = load_model(model_name='gpt2-medium')
print(generate_text("Once upon a time, ", model=model, tokenizer=tokenizer))
|