# test / app.py
# Hugging Face Space by Chengxb888 — commit "Update app.py" (576bbe0, verified), 721 bytes.
from functools import lru_cache

import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Single FastAPI application instance; the route decorators below attach to it.
app = FastAPI()
@app.get("/")
def greet_json():
    """Health-check style root endpoint returning a fixed greeting payload."""
    payload = {"Hello": "World!"}
    return payload
@app.get("/hello/{msg}")
def say_hello(msg: str):
    """Generate a completion for *msg* with google/gemma-2b-it.

    Parameters:
        msg: Raw prompt text taken from the URL path segment.

    Returns:
        dict: ``{"message": <decoded model output>}``.
    """
    print("model")
    tokenizer, model = _load_gemma()
    print("token & msg")
    # BUG FIX: inputs were previously forced to "cpu" while the model is
    # placed with device_map="auto" (possibly on GPU) — keep them together.
    input_ids = tokenizer(msg, return_tensors="pt").to(model.device)
    print("output")
    # Inference only: no_grad avoids building an autograd graph.
    # max_new_tokens replaces the deprecated max_length (which counted the
    # prompt toward the 500-token budget).
    with torch.no_grad():
        outputs = model.generate(**input_ids, max_new_tokens=500)
    print("complete")
    # skip_special_tokens drops <bos>/<eos> markers from the reply text.
    return {"message": tokenizer.decode(outputs[0], skip_special_tokens=True)}


@lru_cache(maxsize=1)
def _load_gemma():
    """Load the gemma-2b-it tokenizer/model pair once per process.

    Previously both were re-downloaded/re-loaded on every request, which is
    extremely slow; lru_cache makes subsequent requests reuse them.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return tokenizer, model