BeveledCube committed
Commit e3db2f8
1 parent: bc27fb1

Upload main.py

Files changed (1): main.py (+9 -4)
main.py CHANGED
@@ -7,19 +7,22 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 model_name = "facebook/blenderbot-1B-distill"
 
-# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbo
+# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot
 # facebook/blenderbot-3B
 # facebook/blenderbot-1B-distill
 # facebook/blenderbot-400M-distill
 # facebook/blenderbot-90M
 # facebook/blenderbot_small-90M
 
+# https://www.youtube.com/watch?v=irjYqV6EebU
+
 app = FastAPI()
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class req(BaseModel):
     prompt: str
+    history: list
 
 @app.get("/")
 def read_root():
@@ -29,14 +32,16 @@ def read_root():
 def read_root(data: req):
     print("Prompt:", data.prompt)
 
+    history_string = "\n".join(data.history)
+
     input_text = data.prompt
-
+
     # Tokenize the input text
-    input_ids = tokenizer.encode(input_text, return_tensors="pt")
+    input_ids = tokenizer.encode(history_string, input_text, return_tensors="pt")
 
     # Generate output using the model
     output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
-    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
 
     answer_data = { "answer": generated_text }
     print("Answer:", generated_text)