|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
|
def load_model(model_name='gpt2-medium'):
    """Fetch a pretrained GPT-2 language model together with its tokenizer.

    Both objects are created with ``from_pretrained`` using the same
    identifier, so they always match each other.

    Args:
        model_name: Model identifier or local path accepted by
            ``from_pretrained`` (defaults to ``'gpt2-medium'``).

    Returns:
        A ``(model, tokenizer)`` tuple: the ``GPT2LMHeadModel`` followed by
        the ``GPT2Tokenizer``.
    """
    # Tuple elements are evaluated left to right, so the model is still
    # loaded before the tokenizer.
    return (
        GPT2LMHeadModel.from_pretrained(model_name),
        GPT2Tokenizer.from_pretrained(model_name),
    )
|
|
|
|
|
def generate_text(input_str, model, tokenizer, length=50, temperature=1.0, *, do_sample=True):
    """Generate a continuation of ``input_str`` with a pretrained GPT-2 model.

    Args:
        input_str: Prompt text to continue.
        model: Causal language model exposing ``generate()`` and ``device``
            (e.g. the ``GPT2LMHeadModel`` returned by ``load_model``).
        tokenizer: Tokenizer matching ``model``. It must be callable with
            ``return_tensors='pt'`` and provide ``decode()`` and
            ``eos_token_id``.
        length: Maximum total sequence length in tokens, *including* the
            prompt tokens (forwarded to ``generate`` as ``max_length``).
        temperature: Sampling temperature. It must be > 0 when sampling;
            lower values make the output more deterministic. It is ignored
            when ``do_sample`` is False.
        do_sample: If True (default), tokens are sampled, so ``temperature``
            takes effect. If False, decoding is greedy.

    Returns:
        The decoded text of the first returned sequence. For a decoder-only
        model such as GPT-2 this includes the prompt followed by the
        generated tokens.

    Raises:
        ValueError: If sampling is enabled and ``temperature`` is not > 0.
    """
    gen_kwargs = {
        'max_length': length,
        # Without do_sample=True, generate() decodes greedily and silently
        # ignores ``temperature``.
        'do_sample': do_sample,
        # GPT-2 ships without a padding token; reuse EOS explicitly rather
        # than relying on generate()'s warning-emitting fallback.
        'pad_token_id': tokenizer.eos_token_id,
    }
    if do_sample:
        # Coerce to float: the temperature logits warper rejects ints.
        temperature = float(temperature)
        if not temperature > 0:  # also rejects NaN
            raise ValueError(
                f"temperature must be > 0 when sampling, got {temperature}"
            )
        gen_kwargs['temperature'] = temperature

    # Calling the tokenizer (instead of ``encode``) also returns the
    # attention mask. The tensors are moved to wherever the model lives.
    inputs = tokenizer(input_str, return_tensors='pt').to(model.device)

    out = model.generate(**inputs, **gen_kwargs)

    # ``out`` holds one row per returned sequence; a single prompt gives one.
    return tokenizer.decode(out[0])
|
|
|
if __name__ == "__main__":
    # Guarded so that importing this module does not load the (large) model
    # and run generation as a side effect; behavior when executed as a
    # script is unchanged.
    model, tokenizer = load_model('gpt2-medium')
    print(generate_text("Once upon a time, ", model, tokenizer))
|
|