# TextGeneration1 / inference.py
# (Hugging Face Hub file; web-page chrome removed so the module parses.)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
class HelloWorldModel(PreTrainedModel):
    """Minimal "hello world" model stub that echoes its input ids.

    FIX(review): the original subclassed ``AutoModelForCausalLM``, but the
    ``Auto*`` classes are factories — their ``__init__`` raises
    ``EnvironmentError`` and ``from_pretrained`` dispatches to a class chosen
    from the config, never to this subclass.  Subclassing ``PreTrainedModel``
    keeps the same caller-facing interface (``__init__(config)``,
    ``from_pretrained``, ``forward``) while making construction actually work.
    """

    # Tells PreTrainedModel.from_pretrained which config class to load
    # from the checkpoint directory.
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, **kwargs):
        # Dummy forward pass: the input ids are returned unchanged as
        # "logits".  NOTE(review): real logits would be a float tensor of
        # shape (batch, seq, vocab_size); ``generate`` cannot sample from
        # this echo output — confirm whether a real LM head is intended.
        return {"logits": input_ids}
def generate_text(prompt):
    """Return the model's text completion for *prompt*.

    Loads the model and tokenizer from the current directory (``"."``),
    tokenizes the prompt, runs generation, and decodes the first sequence
    with special tokens stripped.
    """
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = HelloWorldModel.from_pretrained(".")

    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded)

    text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return text
if __name__ == "__main__":
    # Smoke-test the pipeline with a fixed prompt when run as a script.
    print(generate_text("hello"))