from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "facebook/blenderbot-1B-distill"
# Other available BlenderBot checkpoints
# (https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot):
# facebook/blenderbot-3B
# facebook/blenderbot-1B-distill
# facebook/blenderbot-400M-distill
# facebook/blenderbot-90M
# facebook/blenderbot_small-90M
# https://www.youtube.com/watch?v=irjYqV6EebU

app = FastAPI()

# StaticFiles is imported above but was never used; mounting it here serves the
# front-end assets. The "static" directory name is an assumption, not confirmed
# by the source.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Load the model and tokenizer once at startup
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class req(BaseModel):
    prompt: str
    history: list[str]


@app.get("/")
def read_root():
    # Serve the chat front end
    return FileResponse(path="templates/index.html", media_type="text/html")
@app.post("/api")
def read_root(data: req):
print("Prompt:", data.prompt)
print("History:", data.history)
history_string = "\n".join(data.history)
input_text = data.prompt
# Tokenize the input text
input_ids = tokenizer.encode_plus(history_string, input_text, return_tensors="pt")
# Generate output using the model
output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
answer_data = { "answer": generated_text }
print("Answer:", generated_text)
return answer_data |
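

# Example client call (illustrative, not from the source). It assumes the app
# is served locally on port 8000, e.g. with `uvicorn app:app`, where the module
# name "app" is an assumption:
#
#   curl -X POST http://localhost:8000/api \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello there!", "history": []}'
#
# The response body has the shape {"answer": "<generated reply>"}; the client
# is expected to append each prompt and answer to `history` between turns.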