from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "facebook/blenderbot-1B-distill"

# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot
# facebook/blenderbot-3B
# facebook/blenderbot-1B-distill
# facebook/blenderbot-400M-distill
# facebook/blenderbot-90M
# facebook/blenderbot_small-90M

# https://www.youtube.com/watch?v=irjYqV6EebU

app = FastAPI()

# Load the model and tokenizer once at startup (weights download on first run)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

class ChatRequest(BaseModel):
  prompt: str
  history: list
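
# A sample request body this schema accepts (illustrative values only):
# { "prompt": "Hello, how are you?", "history": ["Hi!", "Hey there!"] }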

@app.get("/")
def read_root():
  return FileResponse(path="templates/index.html", media_type="text/html")

@app.post("/api")
def chat(data: ChatRequest):
  print("Prompt:", data.prompt)

  history_string = "\n".join(data.history)
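  # e.g. ["Hi", "Hello! How are you?"] -> "Hi\nHello! How are you?"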
  
  input_text = data.prompt
  
  # Tokenize the conversation history and the new prompt as a text pair,
  # truncating overlong input to the model's maximum context length
  input_ids = tokenizer.encode(history_string, input_text, return_tensors="pt", truncation=True)
  
  # Generate output using the model
  output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
  
  answer_data = { "answer": generated_text }
  print("Answer:", generated_text)
  
  return answer_data
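
# A minimal way to run this file directly; a sketch assuming uvicorn is
# installed (the standard `uvicorn main:app --reload` CLI works too if the
# file is saved as main.py -- both are assumptions, not part of the original):
if __name__ == "__main__":
  import uvicorn
  uvicorn.run(app, host="0.0.0.0", port=8000)

# Example call once the server is up (hypothetical prompt/history values):
#   curl -X POST http://localhost:8000/api \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello!", "history": []}'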