import torch
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


class HelloWorldModel(PreTrainedModel, GenerationMixin):
    """Toy causal LM whose logits simply point back at the current input token."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        # A dummy parameter so save_pretrained/from_pretrained have weights to round-trip.
        self.dummy = torch.nn.Parameter(torch.zeros(1))

    def forward(self, input_ids, **kwargs):
        # One-hot logits over the vocabulary: greedy decoding keeps re-emitting the
        # last prompt token. Assumes the local config.json defines vocab_size.
        logits = torch.nn.functional.one_hot(input_ids, num_classes=self.config.vocab_size).float()
        return CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Minimal hook used by generate(); this toy model only ever needs the token ids.
        return {"input_ids": input_ids}


def generate_text(prompt):
    # Load the model and tokenizer files that live next to this script.
    model = HelloWorldModel.from_pretrained(".")
    tokenizer = AutoTokenizer.from_pretrained(".")

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


if __name__ == "__main__":
    prompt = "hello"
    print(generate_text(prompt))
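
# Setup sketch (an assumption, not part of this repo's documented workflow):
# from_pretrained(".") only succeeds if this directory already contains a
# config.json, a weights file, and tokenizer files. One hypothetical way to
# create them, borrowing the gpt2 tokenizer purely as an example:
#
#   config = PretrainedConfig(vocab_size=50257)           # hypothetical vocab size
#   HelloWorldModel(config).save_pretrained(".")          # writes config.json + weights
#   AutoTokenizer.from_pretrained("gpt2").save_pretrained(".")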