Upload main.py
main.py CHANGED
@@ -7,19 +7,22 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 model_name = "facebook/blenderbot-1B-distill"
 
-# https://huggingface.co/models?sort=trending&search=facebook%
+# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot
 # facebook/blenderbot-3B
 # facebook/blenderbot-1B-distill
 # facebook/blenderbot-400M-distill
 # facebook/blenderbot-90M
 # facebook/blenderbot_small-90M
 
+# https://www.youtube.com/watch?v=irjYqV6EebU
+
 app = FastAPI()
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class req(BaseModel):
     prompt: str
+    history: list
 
 @app.get("/")
 def read_root():
@@ -29,14 +32,16 @@ def read_root():
 def read_root(data: req):
     print("Prompt:", data.prompt)
 
+    history_string = "\n".join(data.history)
+
     input_text = data.prompt
 
     # Tokenize the input text
-    input_ids = tokenizer.encode(input_text, return_tensors="pt")
+    input_ids = tokenizer.encode(history_string, input_text, return_tensors="pt")
 
     # Generate output using the model
     output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
-    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
 
     answer_data = { "answer": generated_text }
     print("Answer:", generated_text)
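
A note on the tokenizer call introduced here: passing two strings to tokenizer.encode() treats them as a (text, text_pair) input, so the final token sequence depends on how the Blenderbot tokenizer assembles pairs when it adds special tokens. If that behavior is in doubt, building one explicit input string is an unambiguous alternative. The sketch below assumes the same history_string and input_text variables as the handler.

    # Hedged alternative: concatenate history and prompt into a single string
    # instead of relying on the tokenizer's sentence-pair handling.
    combined = history_string + "\n" + input_text
    input_ids = tokenizer.encode(combined, return_tensors="pt")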
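
With this commit, the request body gains a history field that the handler joins with newlines and encodes together with the new prompt. A minimal sketch of calling the updated endpoint follows; the route and port are assumptions, since the diff shows the handler def read_root(data: req) but not the decorator above it, so adjust both to the deployed Space.

    import requests

    # Hypothetical base URL; Spaces typically serve on port 7860 locally,
    # and the route here is assumed to be "/".
    url = "http://localhost:7860/"

    payload = {
        "prompt": "What is your favorite book?",
        "history": [
            "Hi, how are you today?",      # earlier user turn
            "I'm doing well, thank you!",  # earlier bot turn
        ],
    }

    resp = requests.post(url, json=payload)
    # Expected shape: {"answer": "..."}, assuming the handler returns answer_data.
    print(resp.json())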