bevelapi / main.py
BeveledCube's picture
Update main.py
04ca331 verified
raw
history blame
1.38 kB
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face checkpoint name; used below to load both the model and its tokenizer.
model_name = "facebook/blenderbot-1B-distill"
# Model-hub search listing Blenderbot checkpoints:
# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot
# Other checkpoint names (presumably drop-in alternatives for model_name -- not verified here):
# facebook/blenderbot-3B
# facebook/blenderbot-1B-distill
# facebook/blenderbot-400M-distill
# facebook/blenderbot-90M
# facebook/blenderbot_small-90M
# Reference video link kept from the original author (content not checked in this review):
# https://www.youtube.com/watch?v=irjYqV6EebU
# Application object that the route decorators further down register handlers on.
app = FastAPI()
# Model and tokenizer are created once, at import time, and shared by every request.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
class req(BaseModel):
    """JSON body expected by the ``POST /api`` handler."""

    # Message the reply should respond to.
    prompt: str
    # Earlier conversation turns; the handler joins them with newlines.
    history: list
@app.get("/")
def read_root():
    """Return ``templates/index.html`` as an HTML file response."""
    index_page = FileResponse(
        path="templates/index.html",
        media_type="text/html",
    )
    return index_page
@app.post("/api")
def read_root(data: req):
    """Generate a model reply to ``data.prompt`` in the context of ``data.history``.

    The history turns and the prompt are combined into one newline-separated
    context string (prompt last), tokenized as a single sequence, trimmed to
    the most recent tokens the model can attend to, and passed to
    ``model.generate``.

    Args:
        data: Request body carrying the new ``prompt`` and the prior
            ``history`` turns.

    Returns:
        A dict of the form ``{"answer": <generated text>}``.
    """
    # NOTE(review): this function shares its name with the GET "/" handler
    # defined earlier in the file. Both routes are registered by their
    # decorators, but the module-level name ends up bound to this one.
    # Renaming is left to the maintainer to avoid changing the module's names.
    print("Prompt:", data.prompt)
    print("History:", data.history)

    # `history` is an untyped list, so coerce each turn to str; str.join
    # would otherwise raise TypeError on any non-string item.
    turns = [str(turn) for turn in data.history]
    turns.append(data.prompt)
    context = "\n".join(turns)

    # Tokenize as ONE sequence. Blenderbot's tokenizer only defines a
    # single-sequence input format, so passing history and prompt as a
    # sequence pair risks the prompt being dropped or mis-formatted.
    inputs = tokenizer(context, return_tensors="pt")

    # Keep only the newest tokens that fit the model's position-embedding
    # table; longer inputs would fail inside generate(). Slicing from the
    # left preserves the prompt and the trailing end-of-sequence token.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    if max_positions is not None:
        inputs = {
            name: tensor[:, -max_positions:] for name, tensor in inputs.items()
        }

    # Generate output using the model
    outputs = model.generate(**inputs)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    answer_data = {"answer": generated_text}
    print("Answer:", generated_text)
    return answer_data