File size: 630 Bytes
510d708 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import AutoModelForCausalLM, AutoTokenizer
class HelloWorldModel(AutoModelForCausalLM):
    """Toy subclass of ``AutoModelForCausalLM`` whose forward pass echoes its input.

    NOTE(review): ``AutoModelForCausalLM`` is normally used as a factory
    (``from_pretrained`` / ``from_config``) rather than as a base class.
    Confirm that direct construction and the ``forward`` override below are
    actually reached when the model is loaded via ``from_pretrained``.
    """

    def __init__(self, config):
        """Pass ``config`` straight through to the parent constructor."""
        super(HelloWorldModel, self).__init__(config)

    def forward(self, input_ids, **kwargs):
        """Return ``input_ids`` unchanged under the ``"logits"`` key.

        Any extra keyword arguments are accepted and ignored.
        """
        return dict(logits=input_ids)
def generate_text(prompt: str, *, model_dir: str = ".") -> str:
    """Generate text from ``prompt`` and return it decoded.

    Args:
        prompt: Text passed to the tokenizer.
        model_dir: Location handed to both ``from_pretrained`` calls.
            Defaults to ``"."`` (the current working directory).

    Returns:
        The first generated sequence, decoded with special tokens skipped.

    Note:
        The model and tokenizer are loaded again on every call.
    """
    model = HelloWorldModel.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # The tokenizer output is unpacked directly into ``generate`` as keyword
    # arguments, so it is requested as tensors ("pt").
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs)
    # Only the first sequence of the generated batch is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Demo entry point: generate from a fixed prompt and print the result.
    prompt = "hello"
    print(generate_text(prompt))