# bevelapi / main.py
# Hosting-page residue captured together with the file (not part of the program):
#   BeveledCube's picture | Upload main.py | e3db2f8 verified | raw | history blame | 1.39 kB
from typing import List

from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Checkpoint identifier used below to load both the model and its tokenizer.
model_name = "facebook/blenderbot-1B-distill"
# Blenderbot checkpoints noted by the author (search page:
# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot):
#   facebook/blenderbot-3B
#   facebook/blenderbot-1B-distill
#   facebook/blenderbot-400M-distill
#   facebook/blenderbot-90M
#   facebook/blenderbot_small-90M
# Video link left by the author: https://www.youtube.com/watch?v=irjYqV6EebU

# ASGI application object that the route decorators below attach to.
app = FastAPI()

# Both objects are created once at import time and shared by every request.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
class req(BaseModel):
    """Request body accepted by ``POST /api``."""

    # Latest user message the bot should reply to.
    prompt: str
    # Earlier conversation lines. The handler joins them with newlines, so
    # every item has to be a string; typing the items lets validation catch
    # bad payloads up front instead of failing later inside ``str.join``.
    # Presumably ordered oldest-first -- confirm against the client page.
    history: List[str]
@app.get("/")
def read_root():
    """Serve the chat front-end page (``templates/index.html``) as HTML."""
    index_page = FileResponse(
        path="templates/index.html",
        media_type="text/html",
    )
    return index_page
@app.post("/api")
def read_root(data: req):
    """Generate a chatbot reply to ``data.prompt`` in the context of ``data.history``.

    The history lines and the new prompt are joined into one newline-separated
    conversation, tokenized, cut down to the model's position limit if
    necessary (keeping the most recent tokens), and passed to beam-search
    generation.

    Args:
        data: Parsed request body carrying the new ``prompt`` and the prior
            ``history`` lines.

    Returns:
        A dict of the form ``{"answer": <generated text>}``.
    """
    # NOTE(review): this function reuses the name of the GET "/" handler
    # defined earlier in the file. Both routes still register because the
    # decorators run at definition time; the name is kept unchanged here so
    # nothing that refers to it breaks, but a distinct name would be clearer.
    print("Prompt:", data.prompt)

    # Encode the whole conversation as ONE sequence. Blenderbot's tokenizer is
    # documented as single-sequence (the second segment of a text pair is
    # ignored), so passing history and prompt as a pair can drop the prompt.
    conversation = "\n".join([*data.history, data.prompt])
    input_ids = tokenizer.encode(conversation, return_tensors="pt")

    # The model has a fixed number of learned positions; longer inputs make
    # generation fail with an index error. Keep the tail of the sequence so
    # the newest turns (and the closing special token) survive.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    if max_positions is not None and input_ids.shape[-1] > max_positions:
        input_ids = input_ids[:, -max_positions:]

    # Beam search; no_repeat_ngram_size=2 forbids repeating any bigram.
    output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    print("Answer:", generated_text)
    return {"answer": generated_text}