File size: 698 Bytes
c5cfb64
9ce6c57
707e859
e08a4cf
 
 
 
707e859
e08a4cf
 
 
 
 
707e859
e08a4cf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

from transformers import GPT2LMHeadModel, GPT2Tokenizer

class EinfachPrompt:
    """Small wrapper that completes text prompts with a pretrained GPT-2 model.

    The tokenizer and model are loaded once in ``__init__`` and reused for
    every call to :meth:`generate`.
    """

    def __init__(self, model_name="gpt2"):
        """Load the GPT-2 tokenizer and language-model head.

        Args:
            model_name: Model identifier or local path passed to both
                ``from_pretrained`` calls. Defaults to ``"gpt2"``, the
                checkpoint this class originally hard-coded.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name)

    def generate(self, prompt, max_length=150, temperature=0.7, do_sample=True):
        """Return *prompt* followed by a model-generated continuation.

        Args:
            prompt: Input text to continue.
            max_length: Maximum total length in tokens; per the
                ``generate`` API this budget includes the prompt tokens.
            temperature: Sampling temperature. Only meaningful when
                ``do_sample`` is true, so it is not forwarded otherwise.
            do_sample: Sample from the next-token distribution instead of
                greedy decoding. Defaults to ``True`` so that ``temperature``
                actually takes effect (previously it was silently ignored).

        Returns:
            The decoded text of the single generated sequence, with special
            tokens removed. The decoded text includes the prompt itself.
        """
        # Calling the tokenizer (rather than ``encode``) yields both
        # ``input_ids`` and ``attention_mask``; passing the mask explicitly
        # keeps ``generate`` from having to infer it.
        inputs = self.tokenizer(prompt, return_tensors="pt")

        gen_kwargs = {
            "max_length": max_length,
            "num_return_sequences": 1,
            "do_sample": do_sample,
            # GPT-2 has no dedicated pad token; reuse EOS so generation has a
            # valid pad id and does not warn about it.
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # temperature is ignored (with a warning) under greedy decoding.
            gen_kwargs["temperature"] = temperature

        outputs = self.model.generate(**inputs, **gen_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

if __name__ == "__main__":
    # Demo run: complete a German sample prompt and print the result.
    demo_prompt = "Erzähl mir etwas über EinfachPrompt."
    generator = EinfachPrompt()
    completion = generator.generate(demo_prompt)
    print(completion)