Text Classification
Adapters
English
File size: 753 Bytes
1cb9a8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def load_model(model_name='gpt2-medium'):
    """
    Load a pretrained GPT-2 language-model head together with its tokenizer.

    Args:
        model_name: Hub identifier or local path passed to ``from_pretrained``
            for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the model is still
    # loaded before the tokenizer.
    return (
        GPT2LMHeadModel.from_pretrained(model_name),
        GPT2Tokenizer.from_pretrained(model_name),
    )


def generate_text(input_str, model, tokenizer, length=50, temperature=1.0,
                  do_sample=None):
    """
    Generate a continuation of ``input_str`` with a pretrained GPT-2 model.

    Args:
        input_str: Prompt text to continue.
        model: A Hugging Face causal LM, e.g. the model returned by
            ``load_model``.
        tokenizer: The tokenizer matching ``model``.
        length: Maximum total length, in tokens, of the generated sequence.
            This includes the prompt tokens because it is passed as
            ``max_length``.
        temperature: Sampling temperature. It only has an effect when
            sampling is enabled.
        do_sample: Whether to sample instead of decoding greedily. ``None``
            (the default) samples only when ``temperature`` differs from 1.0.
            A default call therefore stays deterministic greedy decoding,
            while a non-default temperature actually takes effect.

    Returns:
        The decoded text of the first generated sequence, prompt included.
    """
    if do_sample is None:
        do_sample = temperature != 1.0

    # Calling the tokenizer directly (instead of .encode) also returns the
    # attention mask. Move the tensors to wherever the model's weights live.
    encoded = tokenizer(input_str, return_tensors='pt').to(model.device)

    gen_kwargs = {
        'max_length': length,
        'do_sample': do_sample,
        # GPT-2 defines no pad token; reuse EOS so generate() does not warn.
        'pad_token_id': tokenizer.eos_token_id,
    }
    if do_sample:
        # Passing temperature while decoding greedily is ignored (and warned
        # about), so only forward it when sampling.
        gen_kwargs['temperature'] = temperature

    out = model.generate(**encoded, **gen_kwargs)
    # generate() returns a batch of sequences; take the first one.
    return tokenizer.decode(out[0])

if __name__ == "__main__":
    # Demo: load the checkpoint and print a continuation of a fixed prompt.
    # The guard keeps a plain `import` of this module from loading the model
    # and printing.
    model, tokenizer = load_model('gpt2-medium')
    print(generate_text("Once upon a time, ", model, tokenizer))